add sdpa to ViT [follow up of #29325] #30555

Merged: 1 commit merged into huggingface:main from sdpa-vit on May 16, 2024

Conversation

@hyenal (Contributor) commented Apr 29, 2024

What does this PR do?

Adding support for SDPA to ViT.

This PR is a follow-up of #29325; most (all) of the work was done by @lyaronskaya.

Fixes #28005.

This PR also includes a minor fix in the SDPA doc checks.

I am currently running RUN_SLOW=1 pytest tests/models/ on a GPU and will report the results in this thread.
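
For context, this is a minimal sketch of the SDPA attention pattern being added (illustrative module and parameter names, not the exact transformers classes):

import torch
import torch.nn as nn
import torch.nn.functional as F


class SdpaSelfAttention(nn.Module):
    """Drop-in style self-attention that routes through torch SDPA."""

    def __init__(self, hidden_size: int = 768, num_heads: int = 12, dropout: float = 0.0):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.dropout = dropout
        self.query = nn.Linear(hidden_size, hidden_size)
        self.key = nn.Linear(hidden_size, hidden_size)
        self.value = nn.Linear(hidden_size, hidden_size)

    def _split_heads(self, x: torch.Tensor) -> torch.Tensor:
        # (batch, seq, hidden) -> (batch, heads, seq, head_dim)
        batch, seq, _ = x.shape
        return x.view(batch, seq, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        q = self._split_heads(self.query(hidden_states))
        k = self._split_heads(self.key(hidden_states))
        v = self._split_heads(self.value(hidden_states))
        # scaled_dot_product_attention replaces the explicit softmax(QK^T / sqrt(d)) V
        # of the eager path and dispatches to fused kernels when they are available.
        context = F.scaled_dot_product_attention(
            q, k, v, dropout_p=self.dropout if self.training else 0.0
        )
        batch, _, seq, _ = context.shape
        return context.transpose(1, 2).reshape(batch, seq, self.num_heads * self.head_dim)


attn = SdpaSelfAttention()
out = attn(torch.randn(2, 197, 768))  # 197 tokens = 14x14 patches + [CLS] for ViT-base
print(out.shape)  # torch.Size([2, 197, 768])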

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. @ArthurZucker and @fxmarty have already reviewed this PR, and @amyeroberts may be interested as she commented on the original PR.

@hyenal (Contributor Author) commented Apr 30, 2024

To make things faster I tried running the following on a GPU:

RUN_SLOW=1 pytest tests/models/audio_spectrogram_transformer/ tests/models/deit/ tests/models/videomae/ tests/models/vision_encoder_decoder/ tests/models/vision_text_dual_encoder/ tests/models/vit/ tests/models/vit_mae/ tests/models/vit_msn/ tests/models/yolos/

So far I am getting a few failures, some of them (OOM errors) unrelated to this PR:

====================================================================================== short test summary info =======================================================================================
FAILED tests/models/vision_encoder_decoder/test_modeling_flax_vision_encoder_decoder.py::FlaxViT2GPT2EncoderDecoderModelTest::test_pt_flax_equivalence - RuntimeError: Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) should be the same or input should be a MKLDNN tensor and weight is a dense tensor
FAILED tests/models/vision_encoder_decoder/test_modeling_tf_vision_encoder_decoder.py::TFViT2GPT2EncoderDecoderModelTest::test_pt_tf_model_equivalence - AssertionError: False is not true : outputs.encoder_attentions_0: `pt_outputs` should a tensor when `tf_outputs` is
FAILED tests/models/vision_encoder_decoder/test_modeling_vision_encoder_decoder.py::DeiT2RobertaModelTest::test_encoder_decoder_model_output_attentions - AttributeError: 'NoneType' object has no attribute 'shape'
FAILED tests/models/vision_encoder_decoder/test_modeling_vision_encoder_decoder.py::ViT2BertModelTest::test_encoder_decoder_model_output_attentions - ValueError: You have to specify pixel_values
FAILED tests/models/vision_encoder_decoder/test_modeling_vision_encoder_decoder.py::ViT2TrOCR::test_encoder_decoder_model_output_attentions - ValueError: You have to specify pixel_values
FAILED tests/models/vision_encoder_decoder/test_modeling_vision_encoder_decoder.py::TrOCRModelIntegrationTest::test_inference_handwritten - torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 20.00 MiB. GPU
FAILED tests/models/vision_encoder_decoder/test_modeling_vision_encoder_decoder.py::TrOCRModelIntegrationTest::test_inference_printed - torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 20.00 MiB. GPU
FAILED tests/models/vision_encoder_decoder/test_modeling_vision_encoder_decoder.py::ViT2GPT2ModelIntegrationTest::test_inference_coco_en - torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 20.00 MiB. GPU
FAILED tests/models/vision_encoder_decoder/test_modeling_vision_encoder_decoder.py::DonutModelIntegrationTest::test_inference_cordv2 - torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 20.00 MiB. GPU
FAILED tests/models/vision_encoder_decoder/test_modeling_vision_encoder_decoder.py::DonutModelIntegrationTest::test_inference_docvqa - torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 150.00 MiB. GPU
FAILED tests/models/vision_encoder_decoder/test_modeling_vision_encoder_decoder.py::DonutModelIntegrationTest::test_inference_rvlcdip - torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 20.00 MiB. GPU
FAILED tests/models/vision_text_dual_encoder/test_modeling_flax_vision_text_dual_encoder.py::FlaxViTBertModelTest::test_pt_flax_equivalence - RuntimeError: Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) should be the same or input should be a MKLDNN tensor and weight is a dense tensor
FAILED tests/models/vision_text_dual_encoder/test_modeling_flax_vision_text_dual_encoder.py::FlaxCLIPVisionBertModelTest::test_pt_flax_equivalence - RuntimeError: Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) should be the same or input should be a MKLDNN tensor and weight is a dense tensor
FAILED tests/models/vision_text_dual_encoder/test_modeling_vision_text_dual_encoder.py::ViTBertModelTest::test_pt_flax_equivalence - TypeError: can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.
FAILED tests/models/vision_text_dual_encoder/test_modeling_vision_text_dual_encoder.py::ViTBertModelTest::test_vision_text_output_attention - AttributeError: 'NoneType' object has no attribute 'shape'
FAILED tests/models/vision_text_dual_encoder/test_modeling_vision_text_dual_encoder.py::DeiTRobertaModelTest::test_vision_text_output_attention - AttributeError: 'NoneType' object has no attribute 'shape'
FAILED tests/models/vision_text_dual_encoder/test_modeling_vision_text_dual_encoder.py::CLIPVisionBertModelTest::test_pt_flax_equivalence - TypeError: can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.
FAILED tests/models/yolos/test_image_processing_yolos.py::YolosImageProcessingTest::test_batched_coco_detection_annotations - ImportError: Pycocotools is not installed in your environment.
FAILED tests/models/yolos/test_modeling_yolos.py::YolosModelTest::test_attention_outputs - AttributeError: 'NoneType' object has no attribute 'shape'
================================================================ 19 failed, 820 passed, 427 skipped, 95 warnings in 607.33s (0:10:07) ===============================================================

@amyeroberts (Collaborator)

Thanks for working on this and enabling this for our models, @hyenal!

We've literally just merged in a new feature which should help us run slow tests. To enable this, I've added the run-slow label to this PR. To trigger a run of the slow tests could you:

  • Rebase on main to include General PR slow CI #30540
  • Push an empty commit with the message: [run-slow] audio_spectrogram_transformer,deit,vit,vit_hybrid,vit_mae,vit_msn,videomae

@hyenal (Contributor Author) commented May 1, 2024

@amyeroberts I rebased and ran the pipeline as indicated. The last run should have failed (I know yolos and the encoder/decoder are not ready yet), so I am not sure whether I did something incorrectly.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@amyeroberts (Collaborator)

@hyenal It was just waiting for me to approve the run :) We don't run workflows automatically, for security reasons and to prevent running slow, heavy tests unnecessarily.

@hyenal (Contributor Author) commented May 1, 2024

Thank you @amyeroberts, I will fix the tests then and request a new SLOW run once things are fixed :)

@hyenal (Contributor Author) commented May 2, 2024

@amyeroberts when you have some time, could you trigger the latest slow run I pushed? I fixed most of the issues, but there are 3 failures (ViT2BertModelTest.test_real_model_save_load_from_pretrained, NougatModelIntegrationTest.test_forward_pass, NougatModelIntegrationTest.test_generation) that I did not manage to reproduce locally.

Is there any specific command I should run for these tests?

@amyeroberts (Collaborator)

@hyenal Sure! I've approved the workflow run, which should trigger these tests. I don't think there's anything special you need to do to run them. If you're unable to reproduce locally and they're being run (not skipped), then it's likely just an env or runner issue and we can help try and debug that.

@hyenal (Contributor Author) commented May 2, 2024

The MR is now ready. Three slow tests are failing, but I am unable to find the source of the failures (a precision error due to SDPA?); if possible I would like to get some help on it.

To further check that everything is working, I could push a slow-test pipeline commit to verify that all changed models still pass in slow mode?
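
For reference, a minimal sketch (not from this PR) of how the eager and SDPA paths can be compared numerically, using the tiny test checkpoint that appears in the failing test; the attn_implementation argument assumes a recent transformers version:

import torch
from transformers import ViTModel

checkpoint = "hf-internal-testing/tiny-random-vit"
eager = ViTModel.from_pretrained(checkpoint, attn_implementation="eager").eval()
sdpa = ViTModel.from_pretrained(checkpoint, attn_implementation="sdpa").eval()

cfg = eager.config
pixel_values = torch.randn(1, cfg.num_channels, cfg.image_size, cfg.image_size)

with torch.no_grad():
    out_eager = eager(pixel_values).last_hidden_state
    out_sdpa = sdpa(pixel_values).last_hidden_state

# Fused SDPA kernels may reorder floating-point operations, so bitwise equality
# is not expected; a small tolerance is the usual check.
print("max abs diff:", (out_eager - out_sdpa).abs().max().item())
print("allclose(atol=1e-4):", torch.allclose(out_eager, out_sdpa, atol=1e-4))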

@hyenal (Contributor Author) commented May 3, 2024

@amyeroberts I am afraid that I cannot find a direct link between this PR and the current failures:

  • tests/models/vision_encoder_decoder/test_modeling_vision_encoder_decoder.py::ViT2BertModelTest::test_real_model_save_load_from_pretrained: this test also fails for me on main. It seems that some parameters are not properly initialised, according to the stderr:
Some weights of ViTModel were not initialized from the model checkpoint at hf-internal-testing/tiny-random-vit and are newly initialized: ['vit.pooler.dense.bias', 'vit.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of BertLMHeadModel were not initialized from the model checkpoint at hf-internal-testing/tiny-bert and are newly initialized: ['bert.encoder.layer.0.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.0.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.0.crossattention.output.dense.bias', 'bert.encoder.layer.0.crossattention.output.dense.weight', 'bert.encoder.layer.0.crossattention.self.key.bias', 'bert.encoder.layer.0.crossattention.self.key.weight', 'bert.encoder.layer.0.crossattention.self.query.bias', 'bert.encoder.layer.0.crossattention.self.query.weight', 'bert.encoder.layer.0.crossattention.self.value.bias', 'bert.encoder.layer.0.crossattention.self.value.weight', 'bert.encoder.layer.1.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.1.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.1.crossattention.output.dense.bias', 'bert.encoder.layer.1.crossattention.output.dense.weight', 'bert.encoder.layer.1.crossattention.self.key.bias', 'bert.encoder.layer.1.crossattention.self.key.weight', 'bert.encoder.layer.1.crossattention.self.query.bias', 'bert.encoder.layer.1.crossattention.self.query.weight', 'bert.encoder.layer.1.crossattention.self.value.bias', 'bert.encoder.layer.1.crossattention.self.value.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  • Nougat tests: Nougat uses Swin, which is not part of this PR. The currently failing tests are also failing on main on my side.

@amyeroberts (Collaborator) left a comment

Thanks for iterating on this!

If tests are failing on main and are unrelated, it's OK for us to ignore them. At the moment, the test suite changes need to be updated to keep in line with the rest of the library and to avoid erroneously skipping tests:

  • Remove has_attentions = False
  • Make all tests use eager mode by default
  • Set eager mode in the config creation

@hyenal force-pushed the sdpa-vit branch 2 times, most recently from 80a3d20 to ae94544, on May 8, 2024 at 19:52
@amyeroberts (Collaborator) left a comment

Thanks for the continued work on this!

Most comments are nits - in particular the line splits, which should be left as they were.

Main comment is about having a self.attn_implementation argument to control the config creation.
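
A rough sketch of that suggestion (illustrative field names and values, not the actual test code, and assuming the config constructor accepts attn_implementation as in recent transformers versions):

from transformers import ViTConfig


class ViTModelTester:
    """Skeleton of a model tester; the real one carries many more fields."""

    def __init__(
        self,
        image_size=30,
        patch_size=2,
        hidden_size=32,
        num_hidden_layers=2,
        num_attention_heads=4,
        attn_implementation="eager",
    ):
        self.image_size = image_size
        self.patch_size = patch_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        # Tests default to the eager path; SDPA-specific tests override this.
        self.attn_implementation = attn_implementation

    def get_config(self):
        return ViTConfig(
            image_size=self.image_size,
            patch_size=self.patch_size,
            hidden_size=self.hidden_size,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            attn_implementation=self.attn_implementation,
        )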

Review threads (outdated, now resolved) on:
  • src/transformers/models/vit_mae/modeling_vit_mae.py (2)
  • src/transformers/models/vit_msn/modeling_vit_msn.py (2)
  • tests/test_modeling_common.py (5)
@amyeroberts (Collaborator) left a comment

Looks great - thanks for making all the updates to the tests!

Last thing to do is add performance numbers for the models e.g. like here for Mistral. It's not necessary to run for all of the models (although this would be great!) but getting numbers for deit, vit, vitmae and yolos should be done as they're quite popular.

@hyenal (Contributor Author) commented May 13, 2024

Last thing to do is add performance numbers for the models e.g. like here for Mistral. It's not necessary to run for all of the models (although this would be great!) but getting numbers for deit, vit, vitmae and yolos should be done as they're quite popular.

@amyeroberts that can be done! If you have any script that I could use so that we keep the same format for the images that would be great!

Also, do you mind resolving the comments that are left open? Just to confirm we agree :)

@amyeroberts (Collaborator)

Also, do you mind resolving the comments that are left open? Just to confirm we agree :)

@hyenal Sure! I think I've resolved all of them. Let me know if there's any I missed.

@amyeroberts that can be done! If you have any script that I could use so that we keep the same format for the images that would be great!

I don't have a script to hand, unfortunately. In terms of measuring the speed ups, it's OK to use different images/formats across the different models, as long as the settings for e.g. ViT are consistent.

@hyenal (Contributor Author) commented May 13, 2024

I copied the style of #30390, let me know if the docs are okay.

Code for reproducibility
from collections import defaultdict
from time import perf_counter_ns

import numpy as np
import pandas as pd
import requests
import torch
from PIL import Image
from tabulate import tabulate

BATCH_SIZES = [1, 2, 4, 8]
ATTN_IMPLEMENTATION = ["eager", "sdpa"]


def profile_ast(
    attn_implementation: str = "eager",
    n_trial: int = 10,
    batch_size: int = 1,
    use_cuda: bool = False,
    dtype=torch.float32,
) -> int:
    import torch
    from datasets import load_dataset

    from transformers import ASTForAudioClassification, AutoFeatureExtractor

    dataset = load_dataset(
        "hf-internal-testing/librispeech_asr_demo", "clean", split="validation"
    )
    dataset = dataset.sort("id")
    sampling_rate = dataset.features["audio"].sampling_rate

    feature_extractor = AutoFeatureExtractor.from_pretrained(
        "MIT/ast-finetuned-audioset-10-10-0.4593",
        torch_dtype=dtype,
    )
    model = ASTForAudioClassification.from_pretrained(
        "MIT/ast-finetuned-audioset-10-10-0.4593",
        attn_implementation=attn_implementation,
        torch_dtype=dtype,
    )

    inputs = feature_extractor(
        dataset[0]["audio"]["array"],
        sampling_rate=sampling_rate,
        return_tensors="pt",
    )  # .to("cuda")
    inputs["input_values"] = inputs["input_values"].tile((batch_size, 1, 1))
    if use_cuda:
        inputs["input_values"] = inputs["input_values"].to("cuda")

    total_time = 0.0

    if use_cuda:
        model = model.to("cuda")
    for _ in range(n_trial):
        time_start = perf_counter_ns()
        with torch.no_grad():
            model(**inputs)
        time_end = perf_counter_ns()

        total_time += (time_end - time_start) / 1e6

    return int(total_time / n_trial)


def profile_deit(
    attn_implementation: str = "eager",
    n_trial: int = 10,
    batch_size: int = 1,
    use_cuda: bool = False,
    dtype=torch.float32,
) -> int:
    from transformers import AutoImageProcessor, DeiTForImageClassification

    torch.manual_seed(3)
    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    image = Image.open(requests.get(url, stream=True).raw)

    # note: we are loading a DeiTForImageClassificationWithTeacher from the hub here,
    # so the head will be randomly initialized, hence the predictions will be random
    image_processor = AutoImageProcessor.from_pretrained(
        "facebook/deit-base-distilled-patch16-224",
        torch_dtype=dtype,
    )
    model = DeiTForImageClassification.from_pretrained(
        "facebook/deit-base-distilled-patch16-224",
        attn_implementation=attn_implementation,
        torch_dtype=dtype,
    )
    if use_cuda:
        model = model.to("cuda")

    inputs = image_processor(images=image, return_tensors="pt")
    inputs["pixel_values"] = inputs["pixel_values"].tile((batch_size, 1, 1, 1))
    if use_cuda:
        inputs["pixel_values"] = inputs["pixel_values"].to("cuda")

    total_time = 0.0
    for _ in range(n_trial):
        time_start = perf_counter_ns()
        with torch.no_grad():
            model(**inputs)
        time_end = perf_counter_ns()

        total_time += (time_end - time_start) / 1e6
    return int(total_time / n_trial)


def profile_vit(
    attn_implementation: str = "eager",
    n_trial: int = 10,
    batch_size: int = 1,
    use_cuda: bool = False,
    dtype=torch.float32,
):
    import torch
    from datasets import load_dataset

    from transformers import AutoImageProcessor, ViTForImageClassification

    dataset = load_dataset("huggingface/cats-image")
    image = dataset["test"]["image"][0]

    image_processor = AutoImageProcessor.from_pretrained(
        "google/vit-base-patch16-224",
        torch_dtype=dtype,
    )
    model = ViTForImageClassification.from_pretrained(
        "google/vit-base-patch16-224",
        attn_implementation=attn_implementation,
        torch_dtype=dtype,
    )
    if use_cuda:
        model = model.to("cuda")

    inputs = image_processor(image, return_tensors="pt")
    inputs["pixel_values"] = inputs["pixel_values"].tile((batch_size, 1, 1, 1))
    if use_cuda:
        inputs["pixel_values"] = inputs["pixel_values"].to("cuda")

    total_time = 0.0
    for _ in range(n_trial):
        time_start = perf_counter_ns()
        with torch.no_grad():
            model(**inputs)
        time_end = perf_counter_ns()

        total_time += (time_end - time_start) / 1e6
    return int(total_time / n_trial)


def profile_vit_hybrid(
    attn_implementation: str = "eager",
    n_trial: int = 10,
    batch_size: int = 1,
    use_cuda: bool = False,
    dtype=torch.float32,
):
    import torch
    from datasets import load_dataset

    from transformers import AutoImageProcessor, ViTHybridForImageClassification

    dataset = load_dataset("huggingface/cats-image")
    image = dataset["test"]["image"][0]

    image_processor = AutoImageProcessor.from_pretrained(
        "google/vit-hybrid-base-bit-384",
        torch_dtype=dtype,
    )
    model = ViTHybridForImageClassification.from_pretrained(
        "google/vit-hybrid-base-bit-384",
        attn_implementation=attn_implementation,
        torch_dtype=dtype,
    )
    if use_cuda:
        model = model.to("cuda")

    inputs = image_processor(image, return_tensors="pt")
    inputs["pixel_values"] = inputs["pixel_values"].tile((batch_size, 1, 1, 1))
    if use_cuda:
        inputs["pixel_values"] = inputs["pixel_values"].to("cuda")

    total_time = 0.0
    for _ in range(n_trial):
        time_start = perf_counter_ns()
        with torch.no_grad():
            model(**inputs)
        time_end = perf_counter_ns()

        total_time += (time_end - time_start) / 1e6
    return int(total_time / n_trial)


def profile_vit_mae(
    attn_implementation: str = "eager",
    n_trial: int = 10,
    batch_size: int = 1,
    use_cuda: bool = False,
    dtype=torch.float32,
):
    # Vit Mae
    from transformers import AutoImageProcessor, ViTMAEModel

    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    image = Image.open(requests.get(url, stream=True).raw)

    image_processor = AutoImageProcessor.from_pretrained(
        "facebook/vit-mae-base",
        torch_dtype=dtype,
    )
    model = ViTMAEModel.from_pretrained(
        "facebook/vit-mae-base",
        attn_implementation=attn_implementation,
        torch_dtype=dtype,
    )
    if use_cuda:
        model = model.to("cuda")

    inputs = image_processor(images=image, return_tensors="pt")
    inputs["pixel_values"] = inputs["pixel_values"].tile((batch_size, 1, 1, 1))
    if use_cuda:
        inputs["pixel_values"] = inputs["pixel_values"].to("cuda")

    total_time = 0.0
    for _ in range(n_trial):
        time_start = perf_counter_ns()
        with torch.no_grad():
            model(**inputs)
        time_end = perf_counter_ns()

        total_time += (time_end - time_start) / 1e6
    return int(total_time / n_trial)


def profile_vit_msn(
    attn_implementation: str = "eager",
    n_trial: int = 10,
    batch_size: int = 1,
    use_cuda: bool = False,
    dtype=torch.float32,
):
    from transformers import AutoImageProcessor, ViTMSNForImageClassification

    torch.manual_seed(2)
    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    image = Image.open(requests.get(url, stream=True).raw)

    image_processor = AutoImageProcessor.from_pretrained(
        "facebook/vit-msn-base",
        torch_dtype=dtype,
    )
    model = ViTMSNForImageClassification.from_pretrained(
        "facebook/vit-msn-base",
        attn_implementation=attn_implementation,
        torch_dtype=dtype,
    )
    if use_cuda:
        model = model.to("cuda")

    inputs = image_processor(images=image, return_tensors="pt")
    inputs["pixel_values"] = inputs["pixel_values"].tile((batch_size, 1, 1, 1))
    if use_cuda:
        inputs["pixel_values"] = inputs["pixel_values"].to("cuda")

    total_time = 0.0
    for _ in range(n_trial):
        time_start = perf_counter_ns()
        with torch.no_grad():
            model(**inputs)
        time_end = perf_counter_ns()

        total_time += (time_end - time_start) / 1e6
    return int(total_time / n_trial)


def profile_yolo(
    attn_implementation: str = "eager",
    n_trial: int = 10,
    batch_size: int = 1,
    use_cuda: bool = False,
    dtype=torch.float32,
):
    from transformers import AutoImageProcessor, AutoModelForObjectDetection

    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    image = Image.open(requests.get(url, stream=True).raw)

    image_processor = AutoImageProcessor.from_pretrained(
        "hustvl/yolos-base",
        torch_dtype=dtype,
    )
    model = AutoModelForObjectDetection.from_pretrained(
        "hustvl/yolos-base",
        attn_implementation=attn_implementation,
        torch_dtype=dtype,
    )
    if use_cuda:
        model = model.to("cuda")

    inputs = image_processor(images=image, return_tensors="pt")
    inputs["pixel_values"] = inputs["pixel_values"].tile((batch_size, 1, 1, 1))
    if use_cuda:
        inputs["pixel_values"] = inputs["pixel_values"].to("cuda")

    total_time = 0.0
    for _ in range(n_trial):
        time_start = perf_counter_ns()
        with torch.no_grad():
            model(**inputs)
        time_end = perf_counter_ns()

        total_time += (time_end - time_start) / 1e6
    return int(total_time / n_trial)


def profile_videomae(
    attn_implementation: str = "eager",
    n_trial: int = 10,
    batch_size: int = 1,
    use_cuda: bool = False,
    dtype=torch.float32,
):
    import av
    from huggingface_hub import hf_hub_download

    from transformers import AutoImageProcessor, VideoMAEForVideoClassification

    np.random.seed(0)

    def read_video_pyav(container, indices):
        """
        Decode the video with PyAV decoder.
        Args:
            container (`av.container.input.InputContainer`): PyAV container.
            indices (`List[int]`): List of frame indices to decode.
        Returns:
            result (np.ndarray): np array of decoded frames of shape (num_frames, height, width, 3).
        """
        frames = []
        container.seek(0)
        start_index = indices[0]
        end_index = indices[-1]
        for i, frame in enumerate(container.decode(video=0)):
            if i > end_index:
                break
            if i >= start_index and i in indices:
                frames.append(frame)
        return np.stack([x.to_ndarray(format="rgb24") for x in frames])

    def sample_frame_indices(clip_len, frame_sample_rate, seg_len):
        """
        Sample a given number of frame indices from the video.
        Args:
            clip_len (`int`): Total number of frames to sample.
            frame_sample_rate (`int`): Sample every n-th frame.
            seg_len (`int`): Maximum allowed index of sample's last frame.
        Returns:
            indices (`List[int]`): List of sampled frame indices
        """
        converted_len = int(clip_len * frame_sample_rate)
        end_idx = np.random.randint(converted_len, seg_len)
        start_idx = end_idx - converted_len
        indices = np.linspace(start_idx, end_idx, num=clip_len)
        indices = np.clip(indices, start_idx, end_idx - 1).astype(np.int64)
        return indices

    # video clip consists of 300 frames (10 seconds at 30 FPS)
    file_path = hf_hub_download(
        repo_id="nielsr/video-demo",
        filename="eating_spaghetti.mp4",
        repo_type="dataset",
    )
    container = av.open(file_path)

    # sample 16 frames
    indices = sample_frame_indices(
        clip_len=16, frame_sample_rate=1, seg_len=container.streams.video[0].frames
    )
    video = read_video_pyav(container, indices)

    image_processor = AutoImageProcessor.from_pretrained(
        "MCG-NJU/videomae-base-finetuned-kinetics",
        torch_dtype=dtype,
    )
    model = VideoMAEForVideoClassification.from_pretrained(
        "MCG-NJU/videomae-base-finetuned-kinetics",
        attn_implementation=attn_implementation,
        torch_dtype=dtype,
    )
    if use_cuda:
        model = model.to("cuda")

    inputs = image_processor(list(video), return_tensors="pt")
    inputs["pixel_values"] = inputs["pixel_values"].tile((batch_size, 1, 1, 1, 1))
    if use_cuda:
        inputs["pixel_values"] = inputs["pixel_values"].to("cuda")

    total_time = 0.0
    for _ in range(n_trial):
        time_start = perf_counter_ns()
        with torch.no_grad():
            model(**inputs)
        time_end = perf_counter_ns()

        total_time += (time_end - time_start) / 1e6
    return int(total_time / n_trial)


def print_comparison(
    name: str, batch_sizes: list[int], time_eager: list[float], time_sdpa: list[float]
) -> None:
    df = pd.DataFrame(
        {
            "Batch size": batch_sizes,
            "Average inference time (ms), eager mode": time_eager,
            "Average inference time (ms), sdpa model": time_sdpa,
            "Speed up, Sdpa / Eager (x)": np.array(time_eager) / np.array(time_sdpa),
        }
    )
    print(f"Model: {name}")
    print(tabulate(df, headers=df.columns, showindex=False, tablefmt="github"))


MODELS = {
    "AST": profile_ast,
    "Deit": profile_deit,
    "ViT": profile_vit,
    "ViT Hybrid": profile_vit_hybrid,
    "ViT MAE": profile_vit_mae,
    "ViT MSN": profile_vit_msn,
    "Yolos": profile_yolo,
    "VideoMAE": profile_videomae
}

for model_name, profiler in MODELS.items():
    times = defaultdict(list)
    for attn_implementation in ATTN_IMPLEMENTATION:
        for b in BATCH_SIZES:
            times[attn_implementation].append(
                profiler(
                    attn_implementation=attn_implementation,
                    batch_size=b,
                    dtype=torch.float32,
                    use_cuda=True,
                )
            )

    print_comparison(model_name, BATCH_SIZES, times["eager"], times["sdpa"])

@amyeroberts (Collaborator) left a comment

Thanks for all the work on adding this and improving these models! It's great to have this feature added ❤️

The diff all looks good to me; the only thing left to tackle is the slow tests. For the different tests, it seems that SDPA is being selected in certain cases, e.g. for the model parallelism test for ViTMSN.

For the others, it's possible SDPA is still being selected; the difference in the logits is strange. On last night's CI, which runs all the slow tests, all of the vit, nougat and vision encoder-decoder tests passed.
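
As a side note, a quick way to confirm which attention path a loaded model actually ended up with is to inspect the config; the leading-underscore attribute is internal, so treat this as a debugging aid rather than a stable API:

from transformers import ViTModel

checkpoint = "hf-internal-testing/tiny-random-vit"  # any ViT checkpoint works here

default_model = ViTModel.from_pretrained(checkpoint)
eager_model = ViTModel.from_pretrained(checkpoint, attn_implementation="eager")

print(default_model.config._attn_implementation)  # "sdpa" when torch supports it, otherwise "eager"
print(eager_model.config._attn_implementation)    # "eager"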

@hyenal (Contributor Author) commented May 15, 2024

Is there anything left to do on this MR? Since it has been approved, I am wondering about the next steps in order to merge :)

@amyeroberts (Collaborator)

@hyenal The only thing left is the ViT2Bert vision encoder-decoder integration test. Agreed that this PR shouldn't have any effect on the nougat/donut tests, and we can ignore those.

@hyenal (Contributor Author) commented May 15, 2024

@amyeroberts do you have any recent slow pipeline run on main where ViT2Bert passed? Using an A100 or my local machine (CPU), I got the same error as on this PR twice.
Steps to reproduce:

# Clone repository
pip install -e ".[dev]"
RUN_SLOW=1 pytest tests/models/vision_encoder_decoder/test_modeling_vision_encoder_decoder.py::ViT2BertModelTest::test_real_model_save_load_from_pretrained

I have tried to look for a SLOW pipeline run on main, but all I can find are pending or cancelled pipelines.

@amyeroberts (Collaborator)

@hyenal Let me dig into it and see 🕵️ It'll be tomorrow though, as I'm signing off soon

Squashed commits:
[24ccd2061] [run-slow]vit_msn,vision_encoder_decoder (+24 squashed commits)
Squashed commits:
[08bd27e] [run-slow]vit_msn,vision_encoder_decoder
[ec96a8d] [run-slow]vit_msn
[ead817e] fix vit msn multi gpu
[d12cdc8] [run-slow]audio_spectrogram_transformer,deit,vision_encoder_decoder,vision_text_dual_encoder,vit,vit_hybrid,vit_mae,vit_msn,videomae,yolos
[3fdbfa8] doc
[a3ff33e] finish implementation
[e20b7b7] Update test_modeling_common.py
[e290c58] Update test_modeling_flax_common.py
[d3af86f] comment
[ff7dd32] more comments
[59b1378] suggestion
[7e2ba6d] attn_implementation as attribute of the class
[fe66ab7] minor
[38642b5] Apply suggestions from code review

Accept comments

Co-authored-by: amyeroberts <[email protected]>
[22cde7d] Update tests/test_modeling_common.py

Co-authored-by: amyeroberts <[email protected]>
[48e137c] Update tests/test_modeling_common.py

Co-authored-by: amyeroberts <[email protected]>
[99f4c67] Update tests/test_modeling_common.py

Co-authored-by: amyeroberts <[email protected]>
[96cf20a] Update src/transformers/models/vit_msn/modeling_vit_msn.py

Co-authored-by: amyeroberts <[email protected]>
[c59377d] Update src/transformers/models/vit_mae/modeling_vit_mae.py

Co-authored-by: amyeroberts <[email protected]>
[b70a472] Update tests/models/vision_text_dual_encoder/test_modeling_vision_text_dual_encoder.py

Co-authored-by: amyeroberts <[email protected]>
[00c84d2] [run-slow]audio_spectrogram_transformer,deit,vision_encoder_decoder,vision_text_dual_encoder,vit,vit_hybrid,vit_mae,vit_msn,videomae,yolos
[61f00eb] all tests are passing locally
[e9e0b82] vision encoder/decoder
[4d5076b] test-vision (+20 squashed commits)
Squashed commits:
[d1add8db9] yolo
[9fde65716] fix flax
[986566c28] minor
[ca2f21d1f] vit
[3333efd7a] easy models change
[ebfc214] [run-slow]audio_spectrogram_transformer,deit,vision_encoder_decoder,vision_text_dual_encoder,vit,vit_hybrid,vit_mae,vit_msn,videomae,yolos
[b8b8603] [run-slow]vision_encoder_decoder,vision_text_dual_encoder,yolos
[48ecc7e] all tests are passing locally
[bff7fc3] minor
[62f8830] fix yolo and text_encoder tests
[1215075] [run-slow]audio_spectrogram_transformer,deit,vit,vit_hybrid,vit_mae,vit_msn,videomae
[1064cae] [run-slow]vision_encoder_decoder,vision_text_dual_encoder,yolos
[b7f52ff] [run-slow]audio_spectrogram_transformer,deit,vit,vit_hybrid,vit_mae,vit_msn,videomae
[cffaa10] fix-copies
[ef6c511] test vit hybrid
[7d4ba86] vit hybrid
[66f9190] [run-slow]audio_spectrogram_transformer,deit,vit,vit_hybrid,vit_mae,vit_msn,videomae
[1fcc0a0] fixes
[cfde6eb] fixup
[e77df1e] all except yolo end encoder decoder (+17 squashed commits)
Squashed commits:
[602913e] vit + vit_mae are working
[547f6c4] RUN_SLOW=1 pytest tests/models/audio_spectrogram_transformer/ tests/models/deit/ tests/models/videomae/  passes
[61a97df] it s the complete opposite...
[aefab37] fix more tests
[71802a1] fix all torch tests
[40b12eb] encoder - decoder tests
[941552b] slow decorator where appropriate
[14d055d] has_attentions to yolo and msn
[3381fa1] add correct name
[e261316] repo consistency
[31c6d0c] fixup
[9d21427] minor fix
[11ed2e1] chore
[eca6644] add sdpa to vit-based models
[cffbf39] make fix-copies result
[6468319] fix style
[d324cd0] add sdpa for vit
@amyeroberts (Collaborator)

Hi @hyenal, I got yesterday's full slow model CI run here: https://github.com/huggingface/transformers/actions/runs/9089085470/job/24979852319

And good news - all of the failing tests (nougat, donut, vit2bert) are failing there too 🥳

I'll merge now. Thanks for all the work and patience adding this impactful feature!

@amyeroberts merged commit 1c21f48 into huggingface:main on May 16, 2024
23 checks passed
@hyenal (Contributor Author) commented May 16, 2024

Thank you so much @amyeroberts!!

Successfully merging this pull request may close these issues.

Open to contribution: adding torch.nn.functional.scaled_dot_product_attention support for more architectures