-
Notifications
You must be signed in to change notification settings - Fork 147
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Enabled DETR (Object Detection) model #1046
Open
cfgfung
wants to merge
4
commits into
huggingface:main
Choose a base branch
from
cfgfung:examples/detr_resnet
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
4 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
<!--- | ||
Copyright 2024 The HuggingFace Team. All rights reserved. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
--> | ||
|
||
# Object Detection Example | ||
|
||
This folder contains an example script which demonstrates the usage of DETR to run object detection task on Gaudi platform. | ||
|
||
## Single-HPU inference | ||
|
||
```bash | ||
python3 run_example.py \ | ||
--model_name_or_path facebook/detr-resnet-101 \ | ||
--image_path "http://images.cocodataset.org/val2017/000000039769.jpg" \ | ||
--use_hpu_graphs \ | ||
--bf16 \ | ||
--print_result | ||
``` | ||
|
||
Models that have been validated: | ||
- [facebook/detr-resnet-101](https://huggingface.co/facebook/detr-resnet-101) | ||
- [facebook/detr-resnet-50](https://huggingface.co/facebook/detr-resnet-50) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
#!/usr/bin/env python | ||
# coding=utf-8 | ||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
|
||
# Copied from https://huggingface.co/docs/transformers/model_doc/owlvit | ||
|
||
import argparse | ||
import time | ||
|
||
import habana_frameworks.torch as ht | ||
import requests | ||
import torch | ||
from PIL import Image | ||
from transformers import AutoProcessor, DetrForObjectDetection | ||
|
||
from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi | ||
|
||
|
||
if __name__ == "__main__": | ||
parser = argparse.ArgumentParser() | ||
|
||
parser.add_argument( | ||
"--model_name_or_path", | ||
default="facebook/detr-resnet-101", | ||
type=str, | ||
help="Path of the pre-trained model", | ||
) | ||
parser.add_argument( | ||
"--image_path", | ||
default="http://images.cocodataset.org/val2017/000000039769.jpg", | ||
type=str, | ||
help='Path of the input image. Should be a single string (eg: --image_path "URL")', | ||
) | ||
parser.add_argument( | ||
"--use_hpu_graphs", | ||
action="store_true", | ||
help="Whether to use HPU graphs or not. Using HPU graphs should give better latencies.", | ||
) | ||
parser.add_argument( | ||
"--bf16", | ||
action="store_true", | ||
help="Whether to use bf16 precision for object detection.", | ||
) | ||
parser.add_argument( | ||
"--print_result", | ||
action="store_true", | ||
help="Whether to print the detection results.", | ||
) | ||
|
||
parser.add_argument("--warmup", type=int, default=3, help="Number of warmup iterations for benchmarking.") | ||
parser.add_argument( | ||
"--n_iterations", type=int, default=10, help="Number of inference iterations for benchmarking." | ||
) | ||
|
||
args = parser.parse_args() | ||
|
||
adapt_transformers_to_gaudi() | ||
|
||
# you can specify the revision tag if you don't want the timm dependency | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. not sure how to approach this here, but would be nice to able to pass/toggle between |
||
processor = AutoProcessor.from_pretrained("facebook/detr-resnet-101", revision="no_timm") | ||
model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-101", revision="no_timm") | ||
|
||
image = Image.open(requests.get(args.image_path, stream=True).raw) | ||
|
||
inputs = processor(images=image, return_tensors="pt").to("hpu") | ||
model.to("hpu") | ||
|
||
if args.use_hpu_graphs: | ||
model = ht.hpu.wrap_in_hpu_graph(model) | ||
|
||
autocast = torch.autocast(device_type="hpu", dtype=torch.bfloat16, enabled=args.bf16) | ||
|
||
with torch.no_grad(), autocast: | ||
for i in range(args.warmup): | ||
inputs = processor(images=image, return_tensors="pt").to("hpu") | ||
outputs = model(**inputs) | ||
torch.hpu.synchronize() | ||
|
||
total_model_time = 0 | ||
for i in range(args.n_iterations): | ||
inputs = processor(images=image, return_tensors="pt").to("hpu") | ||
model_start_time = time.time() | ||
outputs = model(**inputs) | ||
torch.hpu.synchronize() | ||
model_end_time = time.time() | ||
total_model_time = total_model_time + (model_end_time - model_start_time) | ||
|
||
if args.print_result: | ||
# convert outputs (bounding boxes and class logits) to COCO API | ||
# let's only keep detections with score > 0.9 | ||
target_sizes = torch.tensor([image.size[::-1]]) | ||
results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.9)[0] | ||
|
||
for score, label, box in zip(results["scores"], results["labels"], results["boxes"]): | ||
box = [round(i, 2) for i in box.tolist()] | ||
print( | ||
f"Detected {model.config.id2label[label.item()]} with confidence " | ||
f"{round(score.item(), 3)} at location {box}" | ||
) | ||
|
||
print("n_iterations: " + str(args.n_iterations)) | ||
print("Total latency (ms): " + str(total_model_time * 1000)) | ||
print("Average latency (ms): " + str(total_model_time * 1000 / args.n_iterations)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from .modeling_detr import ( | ||
gaudi_DetrConvModel_forward, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
def gaudi_DetrConvModel_forward(self, pixel_values, pixel_mask): | ||
""" | ||
Copied from modeling_detr: https://github.com/huggingface/transformers/blob/main/src/transformers/models/detr/modeling_detr.py#L398 | ||
The modications are: | ||
- Use CPU to calculate the position_embeddings and transfer back to HPU | ||
""" | ||
|
||
# send pixel_values and pixel_mask through backbone to get list of (feature_map, pixel_mask) tuples | ||
out = self.conv_encoder(pixel_values, pixel_mask) | ||
pos = [] | ||
self.position_embedding = self.position_embedding.to("cpu") | ||
|
||
for feature_map, mask in out: | ||
# position encoding | ||
feature_map = feature_map.to("cpu") | ||
mask = mask.to("cpu") | ||
pos.append(self.position_embedding(feature_map, mask).to(feature_map.dtype).to("hpu")) | ||
|
||
return out, pos |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,121 @@ | ||
# coding=utf-8 | ||
# Copyright 2024 HuggingFace Inc. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import os | ||
import time | ||
from unittest import TestCase | ||
|
||
import habana_frameworks.torch as ht | ||
import numpy as np | ||
import requests | ||
import torch | ||
from PIL import Image | ||
from transformers import AutoProcessor, DetrForObjectDetection | ||
|
||
from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi | ||
|
||
|
||
adapt_transformers_to_gaudi() | ||
|
||
if os.environ.get("GAUDI2_CI", "0") == "1": | ||
# Gaudi2 CI baselines | ||
LATENCY_DETR_BF16_GRAPH_BASELINE = 7.593865966796875 | ||
else: | ||
# Gaudi1 CI baselines | ||
LATENCY_DETR_BF16_GRAPH_BASELINE = 15.25988267912151 | ||
|
||
|
||
class GaudiDETRTester(TestCase): | ||
""" | ||
Tests for Object Detection - DETR | ||
""" | ||
|
||
def prepare_model_and_processor(self): | ||
model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-101").to("hpu") | ||
model = model.eval() | ||
processor = AutoProcessor.from_pretrained("facebook/detr-resnet-101") | ||
return model, processor | ||
|
||
def prepare_data(self): | ||
image = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw) | ||
return image | ||
|
||
def test_inference_default(self): | ||
model, processor = self.prepare_model_and_processor() | ||
image = self.prepare_data() | ||
inputs = processor(images=image, return_tensors="pt").to("hpu") | ||
outputs = model(**inputs) | ||
target_sizes = torch.Tensor([image.size[::-1]]) | ||
results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.9)[0] | ||
boxes = results["boxes"] | ||
self.assertEqual(len(boxes), 5) | ||
expected_location = np.array([344.0622, 24.8543, 640.3398, 373.7401]) | ||
self.assertLess(np.abs(boxes[0].cpu().detach().numpy() - expected_location).max(), 1) | ||
|
||
def test_inference_autocast(self): | ||
model, processor = self.prepare_model_and_processor() | ||
image = self.prepare_data() | ||
inputs = processor(images=image, return_tensors="pt").to("hpu") | ||
|
||
with torch.autocast(device_type="hpu", dtype=torch.bfloat16): # Autocast BF16 | ||
outputs = model(**inputs) | ||
target_sizes = torch.Tensor([image.size[::-1]]) | ||
results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.9)[0] | ||
boxes = results["boxes"] | ||
self.assertEqual(len(boxes), 5) | ||
expected_location = np.array([342, 25.25, 636, 376]) | ||
self.assertLess(np.abs(boxes[0].to(torch.float32).cpu().detach().numpy() - expected_location).max(), 5) | ||
|
||
def test_inference_hpu_graphs(self): | ||
model, processor = self.prepare_model_and_processor() | ||
image = self.prepare_data() | ||
inputs = processor(images=image, return_tensors="pt").to("hpu") | ||
|
||
model = ht.hpu.wrap_in_hpu_graph(model) # Apply graph | ||
|
||
outputs = model(**inputs) | ||
target_sizes = torch.Tensor([image.size[::-1]]) | ||
results = processor.post_process_object_detection(outputs=outputs, target_sizes=target_sizes, threshold=0.1) | ||
boxes = results[0]["boxes"] | ||
self.assertEqual(len(boxes), 5) | ||
expected_location = np.array([344.0622, 24.8543, 640.3398, 373.7401]) | ||
self.assertLess(np.abs(boxes[0].to(torch.float32).cpu().detach().numpy() - expected_location).max(), 1) | ||
|
||
def test_no_latency_regression_autocast(self): | ||
warmup = 3 | ||
iterations = 10 | ||
|
||
model, processor = self.prepare_model_and_processor() | ||
image = self.prepare_data() | ||
|
||
model = ht.hpu.wrap_in_hpu_graph(model) | ||
|
||
with torch.no_grad(), torch.autocast(device_type="hpu", dtype=torch.bfloat16, enabled=True): | ||
for i in range(warmup): | ||
inputs = processor(images=image, return_tensors="pt").to("hpu") | ||
_ = model(**inputs) | ||
torch.hpu.synchronize() | ||
|
||
total_model_time = 0 | ||
for i in range(iterations): | ||
inputs = processor(images=image, return_tensors="pt").to("hpu") | ||
model_start_time = time.time() | ||
_ = model(**inputs) | ||
torch.hpu.synchronize() | ||
model_end_time = time.time() | ||
total_model_time = total_model_time + (model_end_time - model_start_time) | ||
|
||
latency = total_model_time * 1000 / iterations # in terms of ms | ||
self.assertLessEqual(latency, 1.05 * LATENCY_DETR_BF16_GRAPH_BASELINE) |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if no_timm is used for the test, is this needed?