Enabled DETR (Object Detection) model #1046

Open · wants to merge 4 commits into main
6 changes: 6 additions & 0 deletions Makefile
@@ -41,6 +41,12 @@ fast_tests_diffusers:
python -m pip install .[tests]
python -m pytest tests/test_diffusers.py

# Run unit and integration tests related to object detection
fast_tests_object_detection:
python -m pip install .[tests]
python -m pip install timm
Contributor

if no_timm is used for the test, is this needed?

python -m pytest tests/test_object_detection.py

# Run single-card non-regression tests
slow_tests_1x: test_installs
python -m pytest tests/test_examples.py -v -s -k "single_card"
1 change: 1 addition & 0 deletions README.md
@@ -214,6 +214,7 @@ The following model architectures, tasks and device distributions have been validated
| OWLViT | | <div style="text-align:left"><li>Single card</li></div> | <li>[zero shot object detection](https://github.com/huggingface/optimum-habana/tree/main/examples/zero-shot-object-detection)</li> |
| ClipSeg | | <div style="text-align:left"><li>Single card</li></div> | <li>[object segmentation](https://github.com/huggingface/optimum-habana/tree/main/examples/object-segementation)</li> |
| Llava / Llava-next | | <div style="text-align:left"><li>Single card</li></div> | <li>[image to text](https://github.com/huggingface/optimum-habana/tree/main/examples/image-to-text)</li> |
| DETR | | <div style="text-align:left"><li>Single card</li></div> | <li>[object detection](https://github.com/huggingface/optimum-habana/tree/main/examples/object-detection)</li> |

</div>

1 change: 1 addition & 0 deletions docs/source/index.mdx
@@ -72,6 +72,7 @@ In the tables below, ✅ means single-card, multi-card and DeepSpeed have all been validated
| OWLViT | | <div style="text-align:left"><li>Single card</li></div> | <li>[zero shot object detection](https://github.com/huggingface/optimum-habana/tree/main/examples/zero-shot-object-detection)</li> |
| ClipSeg | | <div style="text-align:left"><li>Single card</li></div> | <li>[object segmentation](https://github.com/huggingface/optimum-habana/tree/main/examples/object-segementation)</li> |
| Llava / Llava-next | | <div style="text-align:left"><li>Single card</li></div> | <li>[image to text](https://github.com/huggingface/optimum-habana/tree/main/examples/image-to-text)</li> |
| DETR | | <div style="text-align:left"><li>Single card</li></div> | <li>[object detection](https://github.com/huggingface/optimum-habana/tree/main/examples/object-detection)</li> |


- Diffusers
34 changes: 34 additions & 0 deletions examples/object-detection/README.md
@@ -0,0 +1,34 @@
<!---
Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->

# Object Detection Example

This folder contains an example script that demonstrates how to use DETR to run an object detection task on the Gaudi platform.

## Single-HPU inference

```bash
python3 run_example.py \
--model_name_or_path facebook/detr-resnet-101 \
--image_path "http://images.cocodataset.org/val2017/000000039769.jpg" \
--use_hpu_graphs \
--bf16 \
--print_result
```

Models that have been validated:
- [facebook/detr-resnet-101](https://huggingface.co/facebook/detr-resnet-101)
- [facebook/detr-resnet-50](https://huggingface.co/facebook/detr-resnet-50)
114 changes: 114 additions & 0 deletions examples/object-detection/run_example.py
@@ -0,0 +1,114 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Adapted from https://huggingface.co/docs/transformers/model_doc/detr

import argparse
import time

import habana_frameworks.torch as ht
import requests
import torch
from PIL import Image
from transformers import AutoProcessor, DetrForObjectDetection

from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi


if __name__ == "__main__":
parser = argparse.ArgumentParser()

parser.add_argument(
"--model_name_or_path",
default="facebook/detr-resnet-101",
type=str,
help="Path of the pre-trained model",
)
parser.add_argument(
"--image_path",
default="http://images.cocodataset.org/val2017/000000039769.jpg",
type=str,
help='Path of the input image. Should be a single string (e.g., --image_path "URL")',
)
parser.add_argument(
"--use_hpu_graphs",
action="store_true",
help="Whether to use HPU graphs or not. Using HPU graphs should give better latencies.",
)
parser.add_argument(
"--bf16",
action="store_true",
help="Whether to use bf16 precision for object detection.",
)
parser.add_argument(
"--print_result",
action="store_true",
help="Whether to print the detection results.",
)

parser.add_argument("--warmup", type=int, default=3, help="Number of warmup iterations for benchmarking.")
parser.add_argument(
"--n_iterations", type=int, default=10, help="Number of inference iterations for benchmarking."
)

args = parser.parse_args()

adapt_transformers_to_gaudi()

# you can specify the revision tag if you don't want the timm dependency
Contributor
Not sure how to approach this here, but it would be nice to be able to pass/toggle between the no_timm and main revisions.
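
A possible way to do that, sketched below; this is not part of the PR, and the `--revision` flag is hypothetical:

```python
# Hypothetical flag addressing the comment above (not in this PR):
# toggle between the "no_timm" and "main" revisions of the model repo.
parser.add_argument(
    "--revision",
    default="no_timm",
    type=str,
    help='Model revision to load, e.g. "no_timm" (timm-free weights) or "main".',
)

# After parsing, thread it through both from_pretrained calls:
# processor = AutoProcessor.from_pretrained(args.model_name_or_path, revision=args.revision)
# model = DetrForObjectDetection.from_pretrained(args.model_name_or_path, revision=args.revision)
```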

processor = AutoProcessor.from_pretrained(args.model_name_or_path, revision="no_timm")
model = DetrForObjectDetection.from_pretrained(args.model_name_or_path, revision="no_timm")

image = Image.open(requests.get(args.image_path, stream=True).raw)

inputs = processor(images=image, return_tensors="pt").to("hpu")
model.to("hpu")

if args.use_hpu_graphs:
model = ht.hpu.wrap_in_hpu_graph(model)

autocast = torch.autocast(device_type="hpu", dtype=torch.bfloat16, enabled=args.bf16)

with torch.no_grad(), autocast:
for i in range(args.warmup):
inputs = processor(images=image, return_tensors="pt").to("hpu")
outputs = model(**inputs)
torch.hpu.synchronize()

total_model_time = 0
for i in range(args.n_iterations):
inputs = processor(images=image, return_tensors="pt").to("hpu")
model_start_time = time.time()
outputs = model(**inputs)
torch.hpu.synchronize()
model_end_time = time.time()
total_model_time = total_model_time + (model_end_time - model_start_time)

if args.print_result:
# convert outputs (bounding boxes and class logits) to COCO API
# let's only keep detections with score > 0.9
target_sizes = torch.tensor([image.size[::-1]])
results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.9)[0]

for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
box = [round(i, 2) for i in box.tolist()]
print(
f"Detected {model.config.id2label[label.item()]} with confidence "
f"{round(score.item(), 3)} at location {box}"
)

print("n_iterations: " + str(args.n_iterations))
print("Total latency (ms): " + str(total_model_time * 1000))
print("Average latency (ms): " + str(total_model_time * 1000 / args.n_iterations))
4 changes: 4 additions & 0 deletions optimum/habana/transformers/modeling_utils.py
@@ -115,6 +115,7 @@
gaudi_codegen_block_forward,
gaudi_codegen_model_forward,
gaudi_conv1d_forward,
gaudi_DetrConvModel_forward,
gaudi_esm_for_protein_folding_forward,
gaudi_esmfolding_trunk_forward,
gaudi_falcon_linear_forward,
@@ -519,5 +520,8 @@ def adapt_transformers_to_gaudi():
gaudi_owlvitclasspredictionhead_forward
)

# Optimization for DETR model on Gaudi
transformers.models.detr.modeling_detr.DetrConvModel.forward = gaudi_DetrConvModel_forward

# Tell transformers which Gaudi models support tracing
transformers.utils.fx._SUPPORTED_MODELS += tuple(cls.__name__ for cls in models_with_tracing_support)
1 change: 1 addition & 0 deletions optimum/habana/transformers/models/__init__.py
@@ -37,6 +37,7 @@
gaudi_codegen_block_forward,
gaudi_codegen_model_forward,
)
from .detr import gaudi_DetrConvModel_forward
from .esm import (
gaudi_esm_for_protein_folding_forward,
gaudi_esmfolding_trunk_forward,
3 changes: 3 additions & 0 deletions optimum/habana/transformers/models/detr/__init__.py
@@ -0,0 +1,3 @@
from .modeling_detr import (
gaudi_DetrConvModel_forward,
)
19 changes: 19 additions & 0 deletions optimum/habana/transformers/models/detr/modeling_detr.py
@@ -0,0 +1,19 @@
def gaudi_DetrConvModel_forward(self, pixel_values, pixel_mask):
"""
Copied from modeling_detr: https://github.com/huggingface/transformers/blob/main/src/transformers/models/detr/modeling_detr.py#L398
The modifications are:
- Use CPU to calculate the position_embeddings and transfer back to HPU
"""

# send pixel_values and pixel_mask through backbone to get list of (feature_map, pixel_mask) tuples
out = self.conv_encoder(pixel_values, pixel_mask)
pos = []
self.position_embedding = self.position_embedding.to("cpu")

for feature_map, mask in out:
# position encoding
feature_map = feature_map.to("cpu")
mask = mask.to("cpu")
pos.append(self.position_embedding(feature_map, mask).to(feature_map.dtype).to("hpu"))

return out, pos
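
For context, a minimal sketch of how this override takes effect (names taken from this PR's diffs; this snippet is illustrative and not part of the PR): `adapt_transformers_to_gaudi()` monkey-patches the upstream `DetrConvModel.forward` with the function above.

```python
# Sketch: check that adapt_transformers_to_gaudi() swaps in the Gaudi
# forward defined above, which computes the position embeddings on CPU
# and moves the result back to HPU.
import transformers

from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
from optimum.habana.transformers.models.detr import gaudi_DetrConvModel_forward

adapt_transformers_to_gaudi()

assert transformers.models.detr.modeling_detr.DetrConvModel.forward is gaudi_DetrConvModel_forward
```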
121 changes: 121 additions & 0 deletions tests/test_object_detection.py
@@ -0,0 +1,121 @@
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import time
from unittest import TestCase

import habana_frameworks.torch as ht
import numpy as np
import requests
import torch
from PIL import Image
from transformers import AutoProcessor, DetrForObjectDetection

from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi


adapt_transformers_to_gaudi()

if os.environ.get("GAUDI2_CI", "0") == "1":
# Gaudi2 CI baselines
LATENCY_DETR_BF16_GRAPH_BASELINE = 7.593865966796875
else:
# Gaudi1 CI baselines
LATENCY_DETR_BF16_GRAPH_BASELINE = 15.25988267912151


class GaudiDETRTester(TestCase):
"""
Tests for Object Detection - DETR
"""

def prepare_model_and_processor(self):
model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-101").to("hpu")
model = model.eval()
processor = AutoProcessor.from_pretrained("facebook/detr-resnet-101")
return model, processor

def prepare_data(self):
image = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
return image

def test_inference_default(self):
model, processor = self.prepare_model_and_processor()
image = self.prepare_data()
inputs = processor(images=image, return_tensors="pt").to("hpu")
outputs = model(**inputs)
target_sizes = torch.Tensor([image.size[::-1]])
results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.9)[0]
boxes = results["boxes"]
self.assertEqual(len(boxes), 5)
expected_location = np.array([344.0622, 24.8543, 640.3398, 373.7401])
self.assertLess(np.abs(boxes[0].cpu().detach().numpy() - expected_location).max(), 1)

def test_inference_autocast(self):
model, processor = self.prepare_model_and_processor()
image = self.prepare_data()
inputs = processor(images=image, return_tensors="pt").to("hpu")

with torch.autocast(device_type="hpu", dtype=torch.bfloat16): # Autocast BF16
outputs = model(**inputs)
target_sizes = torch.Tensor([image.size[::-1]])
results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.9)[0]
boxes = results["boxes"]
self.assertEqual(len(boxes), 5)
expected_location = np.array([342, 25.25, 636, 376])
self.assertLess(np.abs(boxes[0].to(torch.float32).cpu().detach().numpy() - expected_location).max(), 5)

def test_inference_hpu_graphs(self):
model, processor = self.prepare_model_and_processor()
image = self.prepare_data()
inputs = processor(images=image, return_tensors="pt").to("hpu")

model = ht.hpu.wrap_in_hpu_graph(model) # Apply graph

outputs = model(**inputs)
target_sizes = torch.Tensor([image.size[::-1]])
results = processor.post_process_object_detection(outputs=outputs, target_sizes=target_sizes, threshold=0.1)
boxes = results[0]["boxes"]
self.assertEqual(len(boxes), 5)
expected_location = np.array([344.0622, 24.8543, 640.3398, 373.7401])
self.assertLess(np.abs(boxes[0].to(torch.float32).cpu().detach().numpy() - expected_location).max(), 1)

def test_no_latency_regression_autocast(self):
warmup = 3
iterations = 10

model, processor = self.prepare_model_and_processor()
image = self.prepare_data()

model = ht.hpu.wrap_in_hpu_graph(model)

with torch.no_grad(), torch.autocast(device_type="hpu", dtype=torch.bfloat16, enabled=True):
for i in range(warmup):
inputs = processor(images=image, return_tensors="pt").to("hpu")
_ = model(**inputs)
torch.hpu.synchronize()

total_model_time = 0
for i in range(iterations):
inputs = processor(images=image, return_tensors="pt").to("hpu")
model_start_time = time.time()
_ = model(**inputs)
torch.hpu.synchronize()
model_end_time = time.time()
total_model_time = total_model_time + (model_end_time - model_start_time)

latency = total_model_time * 1000 / iterations  # in ms
self.assertLessEqual(latency, 1.05 * LATENCY_DETR_BF16_GRAPH_BASELINE)