diff --git a/tests/test_video_mae.py b/tests/test_video_mae.py index 157c165cb..1240319e9 100644 --- a/tests/test_video_mae.py +++ b/tests/test_video_mae.py @@ -14,6 +14,7 @@ # limitations under the License. +import os import time from unittest import TestCase @@ -24,7 +25,12 @@ from transformers import VideoMAEForVideoClassification, VideoMAEImageProcessor -LATENCY_VIDEOMAE_BF16_GRAPH_BASELINE = 17.544198036193848 +if os.environ.get("GAUDI2_CI", "0") == "1": + # Gaudi2 CI baselines + LATENCY_VIDEOMAE_BF16_GRAPH_BASELINE = 17.544198036193848 +else: + # Gaudi2 CI baselines + LATENCY_VIDEOMAE_BF16_GRAPH_BASELINE = 61.953186988830566 MODEL_NAME = "MCG-NJU/videomae-base-finetuned-kinetics"