Skip to content

Commit

Permalink
add sdpa to ViT [follow up of #29325] (#30555)
Browse files Browse the repository at this point in the history
remove blank line (+1 squashed commit)
Squashed commits:
[24ccd2061] [run-slow]vit_msn,vision_encoder_decoder (+24 squashed commits)
Squashed commits:
[08bd27e] [run-slow]vit_msn,vision_encoder_decoder
[ec96a8d] [run-slow]vit_msn
[ead817e] fix vit msn multi gpu
[d12cdc8] [run-slow]audio_spectrogram_transformer,deit,vision_encoder_decoder,vision_text_dual_encoder,vit,vit_hybrid,vit_mae,vit_msn,videomae,yolos
[3fdbfa8] doc
[a3ff33e] finish implementation
[e20b7b7] Update test_modeling_common.py
[e290c58] Update test_modeling_flax_common.py
[d3af86f] comment
[ff7dd32] more comments
[59b1378] suggestion
[7e2ba6d] attn_implementation as attribute of the class
[fe66ab7] minor
[38642b5] Apply suggestions from code review

Accept comments

Co-authored-by: amyeroberts <[email protected]>
[22cde7d] Update tests/test_modeling_common.py

Co-authored-by: amyeroberts <[email protected]>
[48e137c] Update tests/test_modeling_common.py

Co-authored-by: amyeroberts <[email protected]>
[99f4c67] Update tests/test_modeling_common.py

Co-authored-by: amyeroberts <[email protected]>
[96cf20a] Update src/transformers/models/vit_msn/modeling_vit_msn.py

Co-authored-by: amyeroberts <[email protected]>
[c59377d] Update src/transformers/models/vit_mae/modeling_vit_mae.py

Co-authored-by: amyeroberts <[email protected]>
[b70a472] Update tests/models/vision_text_dual_encoder/test_modeling_vision_text_dual_encoder.py

Co-authored-by: amyeroberts <[email protected]>
[00c84d2] [run-slow]audio_spectrogram_transformer,deit,vision_encoder_decoder,vision_text_dual_encoder,vit,vit_hybrid,vit_mae,vit_msn,videomae,yolos
[61f00eb] all tests are passing locally
[e9e0b82] vision encoder/decoder
[4d5076b] test-vision (+20 squashed commits)
Squashed commits:
[d1add8db9] yolo
[9fde65716] fix flax
[986566c28] minor
[ca2f21d1f] vit
[3333efd7a] easy models change
[ebfc214] [run-slow]audio_spectrogram_transformer,deit,vision_encoder_decoder,vision_text_dual_encoder,vit,vit_hybrid,vit_mae,vit_msn,videomae,yolos
[b8b8603] [run-slow]vision_encoder_decoder,vision_text_dual_encoder,yolos
[48ecc7e] all tests are passing locally
[bff7fc3] minor
[62f8830] fix yolo and text_encoder tests
[1215075] [run-slow]audio_spectrogram_transformer,deit,vit,vit_hybrid,vit_mae,vit_msn,videomae
[1064cae] [run-slow]vision_encoder_decoder,vision_text_dual_encoder,yolos
[b7f52ff] [run-slow]audio_spectrogram_transformer,deit,vit,vit_hybrid,vit_mae,vit_msn,videomae
[cffaa10] fix-copies
[ef6c511] test vit hybrid
[7d4ba86] vit hybrid
[66f9190] [run-slow]audio_spectrogram_transformer,deit,vit,vit_hybrid,vit_mae,vit_msn,videomae
[1fcc0a0] fixes
[cfde6eb] fixup
[e77df1e] all except yolo end encoder decoder (+17 squashed commits)
Squashed commits:
[602913e] vit + vit_mae are working
[547f6c4] RUN_SLOW=1 pytest tests/models/audio_spectrogram_transformer/ tests/models/deit/ tests/models/videomae/  passes
[61a97df] it s the complete opposite...
[aefab37] fix more tests
[71802a1] fix all torch tests
[40b12eb] encoder - decoder tests
[941552b] slow decorator where appropriate
[14d055d] has_attentions to yolo and msn
[3381fa1] add correct name
[e261316] repo consistency
[31c6d0c] fixup
[9d21427] minor fix
[11ed2e1] chore
[eca6644] add sdpa to vit-based models
[cffbf39] make fix-copies result
[6468319] fix style
[d324cd0] add sdpa for vit

Co-authored-by: Liubov Yaronskaya <[email protected]>
  • Loading branch information
hyenal and lyaronskaya committed May 16, 2024
1 parent 9fd606d commit 1c21f48
Show file tree
Hide file tree
Showing 34 changed files with 709 additions and 26 deletions.
28 changes: 28 additions & 0 deletions docs/source/en/model_doc/audio-spectrogram-transformer.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,34 @@ the authors compute the stats for a downstream dataset.
- Note that the AST needs a low learning rate (the authors use a 10 times smaller learning rate compared to their CNN model proposed in the
[PSLA paper](https://arxiv.org/abs/2102.01243)) and converges quickly, so please search for a suitable learning rate and learning rate scheduler for your task.

### Using Scaled Dot Product Attention (SDPA)

PyTorch includes a native scaled dot-product attention (SDPA) operator as part of `torch.nn.functional`. This function
encompasses several implementations that can be applied depending on the inputs and the hardware in use. See the
[official documentation](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
or the [GPU Inference](https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#pytorch-scaled-dot-product-attention)
page for more information.

SDPA is used by default for `torch>=2.1.1` when an implementation is available, but you may also set
`attn_implementation="sdpa"` in `from_pretrained()` to explicitly request SDPA to be used.

```
from transformers import ASTForAudioClassification
model = ASTForAudioClassification.from_pretrained("MIT/ast-finetuned-audioset-10-10-0.4593", attn_implementation="sdpa", torch_dtype=torch.float16)
...
```

For the best speedups, we recommend loading the model in half-precision (e.g. `torch.float16` or `torch.bfloat16`).

On a local benchmark (A100-40GB, PyTorch 2.3.0, OS Ubuntu 22.04) with `float32` and `MIT/ast-finetuned-audioset-10-10-0.4593` model, we saw the following speedups during inference.

| Batch size | Average inference time (ms), eager mode | Average inference time (ms), sdpa model | Speed up, Sdpa / Eager (x) |
|--------------|-------------------------------------------|-------------------------------------------|------------------------------|
| 1 | 27 | 6 | 4.5 |
| 2 | 12 | 6 | 2 |
| 4 | 21 | 8 | 2.62 |
| 8 | 40 | 14 | 2.86 |

## Resources

A list of official Hugging Face and community (indicated by 馃寧) resources to help you get started with the Audio Spectrogram Transformer.
Expand Down
28 changes: 28 additions & 0 deletions docs/source/en/model_doc/deit.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,34 @@ This model was contributed by [nielsr](https://huggingface.co/nielsr). The Tenso
*facebook/deit-base-patch16-384*. Note that one should use [`DeiTImageProcessor`] in order to
prepare images for the model.

### Using Scaled Dot Product Attention (SDPA)

PyTorch includes a native scaled dot-product attention (SDPA) operator as part of `torch.nn.functional`. This function
encompasses several implementations that can be applied depending on the inputs and the hardware in use. See the
[official documentation](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
or the [GPU Inference](https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#pytorch-scaled-dot-product-attention)
page for more information.

SDPA is used by default for `torch>=2.1.1` when an implementation is available, but you may also set
`attn_implementation="sdpa"` in `from_pretrained()` to explicitly request SDPA to be used.

```
from transformers import DeiTForImageClassification
model = DeiTForImageClassification.from_pretrained("facebook/deit-base-distilled-patch16-224", attn_implementation="sdpa", torch_dtype=torch.float16)
...
```

For the best speedups, we recommend loading the model in half-precision (e.g. `torch.float16` or `torch.bfloat16`).

On a local benchmark (A100-40GB, PyTorch 2.3.0, OS Ubuntu 22.04) with `float32` and `facebook/deit-base-distilled-patch16-224` model, we saw the following speedups during inference.

| Batch size | Average inference time (ms), eager mode | Average inference time (ms), sdpa model | Speed up, Sdpa / Eager (x) |
|--------------|-------------------------------------------|-------------------------------------------|------------------------------|
| 1 | 8 | 6 | 1.33 |
| 2 | 9 | 6 | 1.5 |
| 4 | 9 | 6 | 1.5 |
| 8 | 8 | 6 | 1.33 |

## Resources

A list of official Hugging Face and community (indicated by 馃寧) resources to help you get started with DeiT.
Expand Down
28 changes: 28 additions & 0 deletions docs/source/en/model_doc/videomae.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,34 @@ alt="drawing" width="600"/>
This model was contributed by [nielsr](https://huggingface.co/nielsr).
The original code can be found [here](https://github.com/MCG-NJU/VideoMAE).

## Using Scaled Dot Product Attention (SDPA)

PyTorch includes a native scaled dot-product attention (SDPA) operator as part of `torch.nn.functional`. This function
encompasses several implementations that can be applied depending on the inputs and the hardware in use. See the
[official documentation](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
or the [GPU Inference](https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#pytorch-scaled-dot-product-attention)
page for more information.

SDPA is used by default for `torch>=2.1.1` when an implementation is available, but you may also set
`attn_implementation="sdpa"` in `from_pretrained()` to explicitly request SDPA to be used.

```
from transformers import VideoMAEForVideoClassification
model = VideoMAEForVideoClassification.from_pretrained("MCG-NJU/videomae-base-finetuned-kinetics", attn_implementation="sdpa", torch_dtype=torch.float16)
...
```

For the best speedups, we recommend loading the model in half-precision (e.g. `torch.float16` or `torch.bfloat16`).

On a local benchmark (A100-40GB, PyTorch 2.3.0, OS Ubuntu 22.04) with `float32` and `MCG-NJU/videomae-base-finetuned-kinetics` model, we saw the following speedups during inference.

| Batch size | Average inference time (ms), eager mode | Average inference time (ms), sdpa model | Speed up, Sdpa / Eager (x) |
|--------------|-------------------------------------------|-------------------------------------------|------------------------------|
| 1 | 37 | 10 | 3.7 |
| 2 | 24 | 18 | 1.33 |
| 4 | 43 | 32 | 1.34 |
| 8 | 84 | 60 | 1.4 |

## Resources

A list of official Hugging Face and community (indicated by 馃寧) resources to help you get started with VideoMAE. If
Expand Down
28 changes: 28 additions & 0 deletions docs/source/en/model_doc/vit.md
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,34 @@ who already converted the weights from JAX to PyTorch. Credits go to him!
language modeling). With this approach, the smaller ViT-B/16 model achieves 79.9% accuracy on ImageNet, a significant
improvement of 2% to training from scratch, but still 4% behind supervised pre-training.

### Using Scaled Dot Product Attention (SDPA)

PyTorch includes a native scaled dot-product attention (SDPA) operator as part of `torch.nn.functional`. This function
encompasses several implementations that can be applied depending on the inputs and the hardware in use. See the
[official documentation](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
or the [GPU Inference](https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#pytorch-scaled-dot-product-attention)
page for more information.

SDPA is used by default for `torch>=2.1.1` when an implementation is available, but you may also set
`attn_implementation="sdpa"` in `from_pretrained()` to explicitly request SDPA to be used.

```
from transformers import ViTForImageClassification
model = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224", attn_implementation="sdpa", torch_dtype=torch.float16)
...
```

For the best speedups, we recommend loading the model in half-precision (e.g. `torch.float16` or `torch.bfloat16`).

On a local benchmark (A100-40GB, PyTorch 2.3.0, OS Ubuntu 22.04) with `float32` and `google/vit-base-patch16-224` model, we saw the following speedups during inference.

| Batch size | Average inference time (ms), eager mode | Average inference time (ms), sdpa model | Speed up, Sdpa / Eager (x) |
|--------------|-------------------------------------------|-------------------------------------------|------------------------------|
| 1 | 7 | 6 | 1.17 |
| 2 | 8 | 6 | 1.33 |
| 4 | 8 | 6 | 1.33 |
| 8 | 8 | 6 | 1.33 |

## Resources

Demo notebooks regarding inference as well as fine-tuning ViT on custom data can be found [here](https://github.com/NielsRogge/Transformers-Tutorials/tree/master/VisionTransformer).
Expand Down
28 changes: 28 additions & 0 deletions docs/source/en/model_doc/vit_hybrid.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,34 @@ substantially fewer computational resources to train.*
This model was contributed by [nielsr](https://huggingface.co/nielsr). The original code (written in JAX) can be
found [here](https://github.com/google-research/vision_transformer).

## Using Scaled Dot Product Attention (SDPA)

PyTorch includes a native scaled dot-product attention (SDPA) operator as part of `torch.nn.functional`. This function
encompasses several implementations that can be applied depending on the inputs and the hardware in use. See the
[official documentation](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
or the [GPU Inference](https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#pytorch-scaled-dot-product-attention)
page for more information.

SDPA is used by default for `torch>=2.1.1` when an implementation is available, but you may also set
`attn_implementation="sdpa"` in `from_pretrained()` to explicitly request SDPA to be used.

```
from transformers import ViTHybridForImageClassification
model = ViTHybridForImageClassification.from_pretrained("google/vit-hybrid-base-bit-384", attn_implementation="sdpa", torch_dtype=torch.float16)
...
```

For the best speedups, we recommend loading the model in half-precision (e.g. `torch.float16` or `torch.bfloat16`).

On a local benchmark (A100-40GB, PyTorch 2.3.0, OS Ubuntu 22.04) with `float32` and `google/vit-hybrid-base-bit-384` model, we saw the following speedups during inference.

| Batch size | Average inference time (ms), eager mode | Average inference time (ms), sdpa model | Speed up, Sdpa / Eager (x) |
|--------------|-------------------------------------------|-------------------------------------------|------------------------------|
| 1 | 29 | 18 | 1.61 |
| 2 | 26 | 18 | 1.44 |
| 4 | 25 | 18 | 1.39 |
| 8 | 34 | 24 | 1.42 |

## Resources

A list of official Hugging Face and community (indicated by 馃寧) resources to help you get started with ViT Hybrid.
Expand Down
28 changes: 28 additions & 0 deletions docs/source/en/model_doc/vit_mae.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,34 @@ consists of Transformer blocks) takes as input. Each mask token is a shared, lea
sin/cos position embeddings are added both to the input of the encoder and the decoder.
- For a visual understanding of how MAEs work you can check out this [post](https://keras.io/examples/vision/masked_image_modeling/).

### Using Scaled Dot Product Attention (SDPA)

PyTorch includes a native scaled dot-product attention (SDPA) operator as part of `torch.nn.functional`. This function
encompasses several implementations that can be applied depending on the inputs and the hardware in use. See the
[official documentation](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
or the [GPU Inference](https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#pytorch-scaled-dot-product-attention)
page for more information.

SDPA is used by default for `torch>=2.1.1` when an implementation is available, but you may also set
`attn_implementation="sdpa"` in `from_pretrained()` to explicitly request SDPA to be used.

```
from transformers import ViTMAEModel
model = ViTMAEModel.from_pretrained("facebook/vit-mae-base", attn_implementation="sdpa", torch_dtype=torch.float16)
...
```

For the best speedups, we recommend loading the model in half-precision (e.g. `torch.float16` or `torch.bfloat16`).

On a local benchmark (A100-40GB, PyTorch 2.3.0, OS Ubuntu 22.04) with `float32` and `facebook/vit-mae-base` model, we saw the following speedups during inference.

| Batch size | Average inference time (ms), eager mode | Average inference time (ms), sdpa model | Speed up, Sdpa / Eager (x) |
|--------------|-------------------------------------------|-------------------------------------------|------------------------------|
| 1 | 11 | 6 | 1.83 |
| 2 | 8 | 6 | 1.33 |
| 4 | 8 | 6 | 1.33 |
| 8 | 8 | 6 | 1.33 |

## Resources

A list of official Hugging Face and community (indicated by 馃寧) resources to help you get started with ViTMAE.
Expand Down
28 changes: 28 additions & 0 deletions docs/source/en/model_doc/vit_msn.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,34 @@ use the [`ViTMSNForImageClassification`] class which is initialized from [`ViTMS
- MSN is particularly useful in the low-shot and extreme low-shot regimes. Notably, it achieves 75.7% top-1 accuracy with only 1% of ImageNet-1K
labels when fine-tuned.

### Using Scaled Dot Product Attention (SDPA)

PyTorch includes a native scaled dot-product attention (SDPA) operator as part of `torch.nn.functional`. This function
encompasses several implementations that can be applied depending on the inputs and the hardware in use. See the
[official documentation](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
or the [GPU Inference](https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#pytorch-scaled-dot-product-attention)
page for more information.

SDPA is used by default for `torch>=2.1.1` when an implementation is available, but you may also set
`attn_implementation="sdpa"` in `from_pretrained()` to explicitly request SDPA to be used.

```
from transformers import ViTMSNForImageClassification
model = ViTMSNForImageClassification.from_pretrained("facebook/vit-msn-base", attn_implementation="sdpa", torch_dtype=torch.float16)
...
```

For the best speedups, we recommend loading the model in half-precision (e.g. `torch.float16` or `torch.bfloat16`).

On a local benchmark (A100-40GB, PyTorch 2.3.0, OS Ubuntu 22.04) with `float32` and `facebook/vit-msn-base` model, we saw the following speedups during inference.

| Batch size | Average inference time (ms), eager mode | Average inference time (ms), sdpa model | Speed up, Sdpa / Eager (x) |
|--------------|-------------------------------------------|-------------------------------------------|------------------------------|
| 1 | 7 | 6 | 1.17 |
| 2 | 8 | 6 | 1.33 |
| 4 | 8 | 6 | 1.33 |
| 8 | 8 | 6 | 1.33 |

## Resources

A list of official Hugging Face and community (indicated by 馃寧) resources to help you get started with ViT MSN.
Expand Down
28 changes: 28 additions & 0 deletions docs/source/en/model_doc/yolos.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,34 @@ alt="drawing" width="600"/>

This model was contributed by [nielsr](https://huggingface.co/nielsr). The original code can be found [here](https://github.com/hustvl/YOLOS).

## Using Scaled Dot Product Attention (SDPA)

PyTorch includes a native scaled dot-product attention (SDPA) operator as part of `torch.nn.functional`. This function
encompasses several implementations that can be applied depending on the inputs and the hardware in use. See the
[official documentation](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
or the [GPU Inference](https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#pytorch-scaled-dot-product-attention)
page for more information.

SDPA is used by default for `torch>=2.1.1` when an implementation is available, but you may also set
`attn_implementation="sdpa"` in `from_pretrained()` to explicitly request SDPA to be used.

```
from transformers import AutoModelForObjectDetection
model = AutoModelForObjectDetection.from_pretrained("hustvl/yolos-base", attn_implementation="sdpa", torch_dtype=torch.float16)
...
```

For the best speedups, we recommend loading the model in half-precision (e.g. `torch.float16` or `torch.bfloat16`).

On a local benchmark (A100-40GB, PyTorch 2.3.0, OS Ubuntu 22.04) with `float32` and `hustvl/yolos-base` model, we saw the following speedups during inference.

| Batch size | Average inference time (ms), eager mode | Average inference time (ms), sdpa model | Speed up, Sdpa / Eager (x) |
|--------------|-------------------------------------------|-------------------------------------------|------------------------------|
| 1 | 106 | 76 | 1.39 |
| 2 | 154 | 90 | 1.71 |
| 4 | 222 | 116 | 1.91 |
| 8 | 368 | 168 | 2.19 |

## Resources

A list of official Hugging Face and community (indicated by 馃寧) resources to help you get started with YOLOS.
Expand Down

0 comments on commit 1c21f48

Please sign in to comment.