-
Notifications
You must be signed in to change notification settings - Fork 147
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add an example of object segmentation (ClipSeg) (#801)
Co-authored-by: Jimin Ha <[email protected]> Co-authored-by: regisss <[email protected]>
- Loading branch information
1 parent
48b2c6d
commit 1ee7a47
Showing
6 changed files
with
272 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
<!--- | ||
Copyright 2024 The HuggingFace Team. All rights reserved. | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
--> | ||
|
||
# Object Segmentation Examples | ||
|
||
This directory contains an example script that demonstrates how to perform object segmentation on Gaudi with graph mode. | ||
|
||
## Single-HPU inference | ||
|
||
```bash | ||
python3 run_example.py \ | ||
--model_name_or_path "CIDAS/clipseg-rd64-refined" \ | ||
--image_path "http://images.cocodataset.org/val2017/000000039769.jpg" \ | ||
--prompt "cat, remote, blanket" \ | ||
--warmup 3 \ | ||
--n_iterations 20 \ | ||
--use_hpu_graphs \ | ||
--bf16 \ | ||
--print_result | ||
``` | ||
Models that have been validated: | ||
- [clipseg-rd64-refined ](https://huggingface.co/CIDAS/clipseg-rd64-refined) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,118 @@ | ||
#!/usr/bin/env python | ||
# coding=utf-8 | ||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
|
||
# Copied from https://huggingface.co/docs/transformers/main/en/model_doc/clipseg | ||
|
||
import argparse | ||
import time | ||
|
||
import habana_frameworks.torch as ht | ||
import requests | ||
import torch | ||
from PIL import Image | ||
from torchvision.utils import save_image | ||
from transformers import AutoProcessor, CLIPSegForImageSegmentation | ||
|
||
from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi | ||
|
||
|
||
if __name__ == "__main__": | ||
parser = argparse.ArgumentParser() | ||
|
||
parser.add_argument( | ||
"--model_name_or_path", | ||
default="CIDAS/clipseg-rd64-refined", | ||
type=str, | ||
help="Path of the pre-trained model", | ||
) | ||
parser.add_argument( | ||
"--image_path", | ||
default="http://images.cocodataset.org/val2017/000000039769.jpg", | ||
type=str, | ||
help='Path of the input image. Should be a single string (eg: --image_path "URL")', | ||
) | ||
parser.add_argument( | ||
"--prompt", | ||
default="a cat,a remote,a blanket", | ||
type=str, | ||
help='Prompt for classification. It should be a string seperated by comma. (eg: --prompt "a photo of a cat, a photo of a dog")', | ||
) | ||
parser.add_argument( | ||
"--use_hpu_graphs", | ||
action="store_true", | ||
help="Whether to use HPU graphs or not. Using HPU graphs should give better latencies.", | ||
) | ||
parser.add_argument( | ||
"--bf16", | ||
action="store_true", | ||
help="Whether to use bf16 precision for classification.", | ||
) | ||
parser.add_argument( | ||
"--print_result", | ||
action="store_true", | ||
help="Whether to print the classification results.", | ||
) | ||
parser.add_argument("--warmup", type=int, default=3, help="Number of warmup iterations for benchmarking.") | ||
parser.add_argument("--n_iterations", type=int, default=5, help="Number of inference iterations for benchmarking.") | ||
|
||
args = parser.parse_args() | ||
|
||
adapt_transformers_to_gaudi() | ||
|
||
processor = AutoProcessor.from_pretrained(args.model_name_or_path) | ||
model = CLIPSegForImageSegmentation.from_pretrained( | ||
args.model_name_or_path | ||
) # Use CLIPSegForImageSegmentation instead of automodel. | ||
# The output will contains the logits which are required to generated segmented images | ||
|
||
image = Image.open(requests.get(args.image_path, stream=True).raw) | ||
texts = [] | ||
for text in args.prompt.split(","): | ||
texts.append(text) | ||
|
||
if args.use_hpu_graphs: | ||
model = ht.hpu.wrap_in_hpu_graph(model) | ||
|
||
autocast = torch.autocast(device_type="hpu", dtype=torch.bfloat16, enabled=args.bf16) | ||
model.to("hpu") | ||
|
||
with torch.no_grad(), autocast: | ||
for i in range(args.warmup): | ||
inputs = processor(text=texts, images=[image] * len(texts), padding=True, return_tensors="pt").to("hpu") | ||
outputs = model(**inputs) | ||
torch.hpu.synchronize() | ||
|
||
total_model_time = 0 | ||
for i in range(args.n_iterations): | ||
inputs = processor(text=texts, images=[image] * len(texts), padding=True, return_tensors="pt").to("hpu") | ||
model_start_time = time.time() | ||
outputs = model(**inputs) | ||
torch.hpu.synchronize() | ||
model_end_time = time.time() | ||
total_model_time = total_model_time + (model_end_time - model_start_time) | ||
|
||
if args.print_result: | ||
if i == 0: # generate/output once only | ||
logits = outputs.logits | ||
for j in range(logits.shape[0]): | ||
threshold = 0.5 | ||
segmented_image = ((torch.sigmoid(logits[j]) > threshold) * 255).unsqueeze(0) | ||
segmented_image = segmented_image.to(torch.float32) | ||
save_image(segmented_image, "segmented_" + texts[j].strip() + ".png") | ||
print("Segmented images are generated.") | ||
|
||
print("n_iterations: " + str(args.n_iterations)) | ||
print("Total latency (ms): " + str(total_model_time * 1000)) | ||
print("Average latency (ms): " + str(total_model_time * 1000 / args.n_iterations)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import time | ||
from unittest import TestCase | ||
|
||
import habana_frameworks.torch as ht | ||
import numpy as np | ||
import requests | ||
import torch | ||
from PIL import Image | ||
from transformers import AutoModel, AutoProcessor | ||
|
||
from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi | ||
|
||
|
||
adapt_transformers_to_gaudi() | ||
|
||
# For Gaudi 2 | ||
LATENCY_ClipSeg_BF16_GRAPH_BASELINE = 5.3107380867004395 | ||
|
||
|
||
class GaudiClipSegTester(TestCase): | ||
""" | ||
Tests for ClipSeg model | ||
""" | ||
|
||
def prepare_model_and_processor(self): | ||
model = AutoModel.from_pretrained("CIDAS/clipseg-rd64-refined").to("hpu") | ||
processor = AutoProcessor.from_pretrained("CIDAS/clipseg-rd64-refined") | ||
model = model.eval() | ||
return model, processor | ||
|
||
def prepare_data(self): | ||
url = "http://images.cocodataset.org/val2017/000000039769.jpg" | ||
image = Image.open(requests.get(url, stream=True).raw) | ||
texts = ["a cat", "a remote", "a blanket"] | ||
return texts, image | ||
|
||
def test_inference_default(self): | ||
model, processor = self.prepare_model_and_processor() | ||
texts, image = self.prepare_data() | ||
inputs = processor(text=texts, images=[image] * len(texts), padding=True, return_tensors="pt").to("hpu") | ||
outputs = model(**inputs) | ||
probs = outputs.logits_per_image.softmax(dim=-1).detach().cpu().numpy()[0] | ||
expected_scores = np.array([0.02889409, 0.87959206, 0.09151383]) # from CPU | ||
self.assertEqual(len(probs), 3) | ||
self.assertLess(np.abs(probs - expected_scores).max(), 0.01) | ||
|
||
def test_inference_autocast(self): | ||
model, processor = self.prepare_model_and_processor() | ||
texts, image = self.prepare_data() | ||
inputs = processor(text=texts, images=[image] * len(texts), padding=True, return_tensors="pt").to("hpu") | ||
|
||
with torch.autocast(device_type="hpu", dtype=torch.bfloat16): # Autocast BF16 | ||
outputs = model(**inputs) | ||
probs = outputs.logits_per_image.softmax(dim=-1).to(torch.float32).detach().cpu().numpy()[0] | ||
expected_scores = np.array([0.02889409, 0.87959206, 0.09151383]) # from CPU | ||
self.assertEqual(len(probs), 3) | ||
self.assertEqual(probs.argmax(), expected_scores.argmax()) | ||
|
||
def test_inference_hpu_graphs(self): | ||
model, processor = self.prepare_model_and_processor() | ||
texts, image = self.prepare_data() | ||
inputs = processor(text=texts, images=[image] * len(texts), padding=True, return_tensors="pt").to("hpu") | ||
|
||
model = ht.hpu.wrap_in_hpu_graph(model) # Apply graph | ||
|
||
outputs = model(**inputs) | ||
probs = outputs.logits_per_image.softmax(dim=-1).to(torch.float32).detach().cpu().numpy()[0] | ||
expected_scores = np.array([0.02889409, 0.87959206, 0.09151383]) # from CPU | ||
self.assertEqual(len(probs), 3) | ||
self.assertEqual(probs.argmax(), expected_scores.argmax()) | ||
|
||
def test_no_latency_regression_autocast(self): | ||
warmup = 3 | ||
iterations = 20 | ||
|
||
model, processor = self.prepare_model_and_processor() | ||
texts, image = self.prepare_data() | ||
|
||
model = ht.hpu.wrap_in_hpu_graph(model) | ||
|
||
with torch.no_grad(), torch.autocast(device_type="hpu", dtype=torch.bfloat16, enabled=True): | ||
for i in range(warmup): | ||
inputs = processor(text=texts, images=[image] * len(texts), padding=True, return_tensors="pt").to( | ||
"hpu" | ||
) | ||
_ = model(**inputs) | ||
torch.hpu.synchronize() | ||
|
||
total_model_time = 0 | ||
for i in range(iterations): | ||
inputs = processor(text=texts, images=[image] * len(texts), padding=True, return_tensors="pt").to( | ||
"hpu" | ||
) | ||
model_start_time = time.time() | ||
_ = model(**inputs) | ||
torch.hpu.synchronize() | ||
model_end_time = time.time() | ||
total_model_time = total_model_time + (model_end_time - model_start_time) | ||
|
||
latency = total_model_time * 1000 / iterations # in terms of ms | ||
self.assertLessEqual(latency, 1.05 * LATENCY_ClipSeg_BF16_GRAPH_BASELINE) |