Skip to content

Commit

Permalink
Ruff formatting for test
Browse files Browse the repository at this point in the history
  • Loading branch information
pi314ever committed Jun 13, 2024
1 parent b086956 commit 9d87649
Showing 1 changed file with 3 additions and 8 deletions.
11 changes: 3 additions & 8 deletions tests/test_video_mae.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import torch
from transformers import VideoMAEForVideoClassification, VideoMAEImageProcessor


LATENCY_VIDEOMAE_BF16_GRAPH_BASELINE = 17.544198036193848
MODEL_NAME = "MCG-NJU/videomae-base-finetuned-kinetics"

Expand Down Expand Up @@ -55,9 +56,7 @@ def outputs_cpu(request):

@pytest.fixture(autouse=True, scope="class")
def model_hpu(request):
request.cls.model_hpu = VideoMAEForVideoClassification.from_pretrained(
MODEL_NAME
).to("hpu")
request.cls.model_hpu = VideoMAEForVideoClassification.from_pretrained(MODEL_NAME).to("hpu")
request.cls.model_hpu_graph = ht.hpu.wrap_in_hpu_graph(request.cls.model_hpu)


Expand All @@ -83,11 +82,7 @@ def test_inference_default(self):
self.outputs_hpu_default.logits.cpu().topk(10).indices,
)
)
self.assertTrue(
torch.allclose(
self.outputs_cpu.logits, self.outputs_hpu_default.logits, atol=5e-3
)
)
self.assertTrue(torch.allclose(self.outputs_cpu.logits, self.outputs_hpu_default.logits, atol=5e-3))

def test_inference_bf16(self):
"""
Expand Down

0 comments on commit 9d87649

Please sign in to comment.