Add an example of FastViT model (Inference) #826

Open: wants to merge 17 commits into base: main
5 changes: 5 additions & 0 deletions Makefile
@@ -41,6 +41,11 @@ fast_tests_diffusers:
	python -m pip install .[tests]
	python -m pytest tests/test_diffusers.py

# Run single-card non-regression tests on image classification models
fast_tests_image_classifications:
	pip install timm
	python -m pytest tests/test_image_classification.py

# Run single-card non-regression tests
slow_tests_1x: test_installs
	python -m pytest tests/test_examples.py -v -s -k "single_card"
10 changes: 10 additions & 0 deletions README.md
@@ -229,6 +229,16 @@ The following model architectures, tasks and device distributions have been vali

</div>

- PyTorch Image Models/TIMM:

<div align="center">

| Architecture | Training | Inference | Tasks |
|---------------------|:--------:|:---------:|:------|
| FastViT | | <div style="text-align:left"><li>Single card</li></div> | <li>[image classification](https://github.com/huggingface/optimum-habana/tree/main/examples/image-classification)</li> |

</div>

- TRL:

<div align="center">
6 changes: 5 additions & 1 deletion docs/source/index.mdx
@@ -74,14 +74,18 @@ In the tables below, ✅ means single-card, multi-card and DeepSpeed have all be
| Llava / Llava-next | | <div style="text-align:left"><li>Single card</li></div> | <li>[image to text](https://github.com/huggingface/optimum-habana/tree/main/examples/image-to-text)</li> |


- Diffusers
- Diffusers:

| Architecture | Training | Inference | Tasks |
|---------------------|:--------:|:---------:|:------|
| Stable Diffusion | <li>[textual inversion](https://github.com/huggingface/optimum-habana/tree/main/examples/stable-diffusion/training#textual-inversion)</li><li>[ControlNet](https://github.com/huggingface/optimum-habana/tree/main/examples/stable-diffusion/training#controlnet-training)</li> | <div style="text-align:left"><li>Single card</li></div> | <li>[text-to-image generation](https://github.com/huggingface/optimum-habana/tree/main/examples/stable-diffusion)</li> |
| Stable Diffusion XL | <li>[fine-tuning](https://github.com/huggingface/optimum-habana/tree/main/examples/stable-diffusion/training#fine-tuning-for-stable-diffusion-xl)</li> | <div style="text-align:left"><li>Single card</li></div> | <li>[text-to-image generation](https://github.com/huggingface/optimum-habana/tree/main/examples/stable-diffusion)</li> |
| LDM3D | | <div style="text-align:left"><li>Single card</li></div> | <li>[text-to-image generation](https://github.com/huggingface/optimum-habana/tree/main/examples/stable-diffusion)</li> |

- PyTorch Image Models/TIMM:

| Architecture | Training | Inference | Tasks |
|---------------------|:--------:|:---------:|:------|
| FastViT | | <div style="text-align:left"><li>Single card</li></div> | <li>[image classification](https://github.com/huggingface/optimum-habana/tree/main/examples/image-classification)</li> |

- TRL:

23 changes: 22 additions & 1 deletion examples/image-classification/README.md
@@ -16,7 +16,7 @@ limitations under the License.

# Image Classification Examples

This directory contains a script that showcases how to fine-tune any model supported by the [`AutoModelForImageClassification` API](https://huggingface.co/docs/transformers/main/en/model_doc/auto#transformers.AutoModelForImageClassification) (such as [ViT](https://huggingface.co/docs/transformers/main/en/model_doc/vit) or [Swin Transformer](https://huggingface.co/docs/transformers/main/en/model_doc/swin)) on HPUs. They can be used to fine-tune models on both [datasets from the hub](#using-datasets-from-hub) as well as on [your own custom data](#using-your-own-data).
This directory contains a script that showcases how to fine-tune any model supported by the [`AutoModelForImageClassification` API](https://huggingface.co/docs/transformers/main/en/model_doc/auto#transformers.AutoModelForImageClassification) (such as [ViT](https://huggingface.co/docs/transformers/main/en/model_doc/vit) or [Swin Transformer](https://huggingface.co/docs/transformers/main/en/model_doc/swin)) on HPUs. It can be used to fine-tune models on both [datasets from the hub](#using-datasets-from-hub) and [your own custom data](#using-your-own-data). This directory also contains a script that demonstrates single-HPU inference for [PyTorch-Image-Models/TIMM](https://huggingface.co/docs/timm/index) models.


## Requirements
@@ -295,3 +295,24 @@ python run_image_classification.py \
--gaudi_config_name Habana/vit \
--dataloader_num_workers 1 \
--bf16

## TIMM/FastViT Examples

This directory also contains an example script, `run_timm_example.py`, that runs FastViT inference with HPU graphs.

### Single-HPU inference

```bash
python3 run_timm_example.py \
--model_name_or_path "timm/fastvit_t8.apple_in1k" \
--image_path "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png" \
--warmup 3 \
--n_iterations 20 \
--use_hpu_graphs \
--bf16 \
--print_result
```

Models that have been validated:

- [timm/fastvit_t8.apple_dist_in1k](https://huggingface.co/timm/fastvit_t8.apple_dist_in1k)
- [timm/fastvit_t8.apple_in1k](https://huggingface.co/timm/fastvit_t8.apple_in1k)
- [timm/fastvit_sa12.apple_in1k](https://huggingface.co/timm/fastvit_sa12.apple_in1k)
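
The `--warmup` and `--n_iterations` flags in the command above follow a common benchmarking pattern: a few untimed calls run first, since the earliest calls typically carry one-off costs such as graph capture or compilation, and the reported latency is averaged over the timed calls only. Below is a minimal, device-agnostic sketch of that pattern. The `benchmark` helper and the stand-in workload are assumptions for illustration and are not part of this PR.

```python
import time


def benchmark(fn, warmup=3, n_iterations=20):
    """Return the average latency of fn() in milliseconds.

    Hypothetical helper for illustration; not part of run_timm_example.py.
    """
    if n_iterations <= 0:
        raise ValueError("n_iterations must be positive")

    # Untimed warmup calls absorb one-off costs such as compilation or caching.
    for _ in range(warmup):
        fn()

    total = 0.0
    for _ in range(n_iterations):
        start = time.perf_counter()
        fn()
        total += time.perf_counter() - start

    return total * 1000 / n_iterations


if __name__ == "__main__":
    # Stand-in workload; on Gaudi this would be the model call followed by
    # torch.hpu.synchronize() so that queued device work is included.
    avg_ms = benchmark(lambda: sum(range(10_000)), warmup=3, n_iterations=20)
    print(f"Average latency (ms): {avg_ms:.3f}")
```

Keeping the warmup calls outside the timed loop is what makes the average comparable between runs with and without `--use_hpu_graphs`.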
1 change: 1 addition & 0 deletions examples/image-classification/requirements.txt
@@ -3,3 +3,4 @@ torchvision>=0.6.0
datasets>=2.14.0
evaluate
scikit-learn
timm>=0.9.16
102 changes: 102 additions & 0 deletions examples/image-classification/run_timm_example.py
@@ -0,0 +1,102 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Copied from https://huggingface.co/timm/fastvit_t8.apple_in1k

import argparse
import time

import habana_frameworks.torch as ht
import requests
import timm
import torch
from PIL import Image

from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--model_name_or_path",
        default="timm/fastvit_t8.apple_in1k",
        type=str,
        help="Name or path of the pre-trained model",
    )
    parser.add_argument(
        "--image_path",
        default="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png",
        type=str,
        help='URL of the input image. Should be a single string (e.g. --image_path "URL")',
    )
    parser.add_argument(
        "--use_hpu_graphs",
        action="store_true",
        help="Whether to use HPU graphs or not. Using HPU graphs should give better latencies.",
    )
    parser.add_argument(
        "--bf16",
        action="store_true",
        help="Whether to use bf16 precision for classification.",
    )
    parser.add_argument(
        "--print_result",
        action="store_true",
        help="Whether to print the classification results.",
    )
    parser.add_argument("--warmup", type=int, default=3, help="Number of warmup iterations for benchmarking.")
    parser.add_argument("--n_iterations", type=int, default=5, help="Number of inference iterations for benchmarking.")

    args = parser.parse_args()

    adapt_transformers_to_gaudi()

    # Load the pre-trained model, move it to the HPU and switch to eval mode.
    model = timm.create_model(args.model_name_or_path, pretrained=True)
    model.to("hpu")
    model = model.eval()

    # Build the preprocessing pipeline that matches the model's data config.
    data_config = timm.data.resolve_model_data_config(model)
    transforms = timm.data.create_transform(**data_config, is_training=False)

    # The image is fetched over HTTP, so --image_path must be a URL.
    img = Image.open(requests.get(args.image_path, stream=True).raw)

    if args.use_hpu_graphs:
        model = ht.hpu.wrap_in_hpu_graph(model)

    autocast = torch.autocast(device_type="hpu", dtype=torch.bfloat16, enabled=args.bf16)

    with torch.no_grad(), autocast:
        # Warmup iterations are excluded from the timing.
        for _ in range(args.warmup):
            inputs = transforms(img).unsqueeze(0).to("hpu")
            outputs = model(inputs)
            torch.hpu.synchronize()

        total_model_time = 0
        for _ in range(args.n_iterations):
            inputs = transforms(img).unsqueeze(0).to("hpu")
            model_start_time = time.perf_counter()
            outputs = model(inputs)
            # Wait for the device to finish so the measurement covers the full forward pass.
            torch.hpu.synchronize()
            model_end_time = time.perf_counter()
            total_model_time = total_model_time + (model_end_time - model_start_time)

        # Report the top-5 classes of the last forward pass.
        if args.print_result:
            top5_probabilities, top5_class_indices = torch.topk(outputs.softmax(dim=1) * 100, k=5)
            print("top5_class_indices: " + str(top5_class_indices.to("cpu").numpy()))

    print("n_iterations: " + str(args.n_iterations))
    print("Total latency (ms): " + str(total_model_time * 1000))
    print("Average latency (ms): " + str(total_model_time * 1000 / args.n_iterations))
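
With `--print_result`, the script applies a softmax to the logits, scales the result to percentages and keeps the five largest entries. The following is a minimal pure-Python sketch of that computation for a single row of logits. The `top_k_percent` helper and the logits are made up for illustration; the script itself calls `torch.topk` on the HPU tensor.

```python
import math


def top_k_percent(logits, k=5):
    """Return (percentages, indices) of the k largest softmax scores.

    Hypothetical stand-in for torch.topk(outputs.softmax(dim=1) * 100, k=k)
    applied to a single row of logits.
    """
    # Subtract the maximum before exponentiating for numerical stability.
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    percents = [100.0 * e / total for e in exps]

    # Sort class indices by descending score and keep the first k.
    order = sorted(range(len(logits)), key=lambda i: percents[i], reverse=True)[:k]
    return [percents[i] for i in order], order


if __name__ == "__main__":
    made_up_logits = [0.1, 2.0, -1.0, 0.5, 3.0, 1.2]  # illustrative values only
    scores, indices = top_k_percent(made_up_logits, k=3)
    print("top3_class_indices:", indices)
```

The returned indices are positions in the model's output vector; mapping them to human-readable labels needs the class list of the dataset the checkpoint was trained on.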
120 changes: 120 additions & 0 deletions tests/test_image_classification.py
@@ -0,0 +1,120 @@
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import time
from unittest import TestCase

import habana_frameworks.torch as ht
import numpy as np
import requests
import timm
import torch
from PIL import Image

from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi


adapt_transformers_to_gaudi()

# For Gaudi 2
LATENCY_FastViT_BF16_GRAPH_BASELINE = 2.5270626640319824


class GaudiFastViTTester(TestCase):
    """
    Tests for FastViT model
    """

    def prepare_model_and_processor(self):
        model = timm.create_model("timm/fastvit_t8.apple_in1k", pretrained=True)
        model.to("hpu")
        model = model.eval()
        data_config = timm.data.resolve_model_data_config(model)
        processor = timm.data.create_transform(**data_config, is_training=False)
        return model, processor

    def prepare_data(self):
        url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png"
        image = Image.open(requests.get(url, stream=True).raw)
        return image

    def test_inference_default(self):
        model, processor = self.prepare_model_and_processor()
        image = self.prepare_data()
        inputs = processor(image).unsqueeze(0).to("hpu")
        outputs = model(inputs)
        top1_probabilities, top1_class_indices = torch.topk(outputs.softmax(dim=1) * 100, k=1)
        top1_probabilities = top1_probabilities.to("cpu").detach().numpy()
        top1_class_indices = top1_class_indices.to("cpu").numpy()
        expected_scores = np.array([21.406523])  # from CPU
        expected_class = np.array([960])
        self.assertEqual(top1_class_indices, expected_class)
        self.assertLess(np.abs(top1_probabilities - expected_scores).max(), 1)

    def test_inference_autocast(self):
        model, processor = self.prepare_model_and_processor()
        image = self.prepare_data()
        inputs = processor(image).unsqueeze(0).to("hpu")

        with torch.autocast(device_type="hpu", dtype=torch.bfloat16):  # Autocast BF16
            outputs = model(inputs)
            top1_probabilities, top1_class_indices = torch.topk(outputs.softmax(dim=1) * 100, k=1)
            top1_probabilities = top1_probabilities.to("cpu").detach().numpy()
            top1_class_indices = top1_class_indices.to("cpu").numpy()
            expected_scores = np.array([21.406523])  # from CPU
            expected_class = np.array([960])
            self.assertEqual(top1_class_indices, expected_class)
            self.assertLess(np.abs(top1_probabilities - expected_scores).max(), 1)

    def test_inference_hpu_graphs(self):
        model, processor = self.prepare_model_and_processor()
        image = self.prepare_data()
        inputs = processor(image).unsqueeze(0).to("hpu")

        model = ht.hpu.wrap_in_hpu_graph(model)  # Apply graph

        outputs = model(inputs)
        top1_probabilities, top1_class_indices = torch.topk(outputs.softmax(dim=1) * 100, k=1)
        top1_probabilities = top1_probabilities.to("cpu").detach().numpy()
        top1_class_indices = top1_class_indices.to("cpu").numpy()
        expected_scores = np.array([21.406523])  # from CPU
        expected_class = np.array([960])
        self.assertEqual(top1_class_indices, expected_class)
        self.assertLess(np.abs(top1_probabilities - expected_scores).max(), 1)

    def test_no_latency_regression_autocast(self):
        warmup = 3
        iterations = 20

        model, processor = self.prepare_model_and_processor()
        image = self.prepare_data()

        model = ht.hpu.wrap_in_hpu_graph(model)

        with torch.no_grad(), torch.autocast(device_type="hpu", dtype=torch.bfloat16, enabled=True):
            # Warmup iterations are excluded from the timing.
            for _ in range(warmup):
                inputs = processor(image).unsqueeze(0).to("hpu")
                _ = model(inputs)
                torch.hpu.synchronize()

            total_model_time = 0
            for _ in range(iterations):
                inputs = processor(image).unsqueeze(0).to("hpu")
                model_start_time = time.perf_counter()
                _ = model(inputs)
                torch.hpu.synchronize()
                model_end_time = time.perf_counter()
                total_model_time = total_model_time + (model_end_time - model_start_time)

        latency = total_model_time * 1000 / iterations  # average latency in ms
        # Allow at most a 5% slowdown relative to the Gaudi 2 baseline.
        self.assertLessEqual(latency, 1.05 * LATENCY_FastViT_BF16_GRAPH_BASELINE)
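
The final assertion lets the measured average latency exceed the Gaudi 2 baseline by at most 5%. Below is a small sketch of that check in isolation. The `within_tolerance` helper and the sample measurements are hypothetical; only the baseline constant comes from the test file.

```python
LATENCY_FastViT_BF16_GRAPH_BASELINE = 2.5270626640319824  # ms, Gaudi 2 value from the test


def within_tolerance(latency_ms, baseline_ms, tolerance=0.05):
    """True if latency_ms is no more than `tolerance` above baseline_ms.

    Hypothetical helper mirroring assertLessEqual(latency, 1.05 * baseline).
    """
    return latency_ms <= (1 + tolerance) * baseline_ms


if __name__ == "__main__":
    for measured in (2.40, 2.65, 2.70):  # illustrative latencies in ms
        ok = within_tolerance(measured, LATENCY_FastViT_BF16_GRAPH_BASELINE)
        print(f"{measured:.2f} ms -> {'pass' if ok else 'regression'}")
```

The bound is one-sided, so a faster run always passes, and the baseline only changes when someone updates the constant.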