Add an example of object segmentation (ClipSeg) #801

Merged
12 commits merged on May 20, 2024
3 changes: 3 additions & 0 deletions Makefile
@@ -78,6 +78,9 @@ slow_tests_trl: test_installs
python -m pip install peft==0.7.0
python -m pytest tests/test_trl.py -v -s -k "test_calculate_loss"

slow_tests_object_segmentation: test_installs
python -m pytest tests/test_object_segmentation.py

# Check if examples are up to date with the Transformers library
example_diff_tests: test_installs
python -m pytest tests/test_examples_match_transformers.py
3 changes: 2 additions & 1 deletion README.md
@@ -198,6 +198,7 @@ The following model architectures, tasks and device distributions have been validated
| ESMFold | | <div style="text-align:left"><li>Single card</li></div> | <li>[protein folding](https://github.com/huggingface/optimum-habana/tree/main/examples/protein-folding)</li> |
| Blip | | <div style="text-align:left"><li>Single card</li></div> | <li>[visual question answering](https://github.com/huggingface/optimum-habana/tree/main/examples/visual-question-answering)</li><li>[image to text](https://github.com/huggingface/optimum-habana/tree/main/examples/image-to-text)</li> |
| OWLViT | | <div style="text-align:left"><li>Single card</li></div> | <li>[zero shot object detection](https://github.com/huggingface/optimum-habana/tree/main/examples/zero-shot-object-detection)</li> |
| ClipSeg | | <div style="text-align:left"><li>Single card</li></div> | <li>[object segmentation](https://github.com/huggingface/optimum-habana/tree/main/examples/object-segementation)</li> |

</div>

@@ -241,4 +242,4 @@ Please refer to Habana Gaudi's official [installation guide](https://docs.habana

## Development

Check the [contributor guide](https://github.com/huggingface/optimum/blob/main/CONTRIBUTING.md) for instructions.
4 changes: 3 additions & 1 deletion docs/source/index.mdx
@@ -70,6 +70,8 @@ In the tables below, ✅ means single-card, multi-card and DeepSpeed have all been validated
| ESMFold | | <div style="text-align:left"><li>Single card</li></div> | <li>[protein folding](https://github.com/huggingface/optimum-habana/tree/main/examples/protein-folding)</li> |
| Blip | | <div style="text-align:left"><li>Single card</li></div> | <li>[visual question answering](https://github.com/huggingface/optimum-habana/tree/main/examples/visual-question-answering)</li><li>[image to text](https://github.com/huggingface/optimum-habana/tree/main/examples/image-to-text)</li> |
| OWLViT | | <div style="text-align:left"><li>Single card</li></div> | <li>[zero shot object detection](https://github.com/huggingface/optimum-habana/tree/main/examples/zero-shot-object-detection)</li> |
| ClipSeg | | <div style="text-align:left"><li>Single card</li></div> | <li>[object segmentation](https://github.com/huggingface/optimum-habana/tree/main/examples/object-segementation)</li> |


- Diffusers

@@ -113,4 +115,4 @@ Besides, [this page](https://github.com/huggingface/optimum-habana/tree/main/exa
<p class="text-gray-700">Technical descriptions of how the Habana classes and methods of 🤗 Optimum Habana work.</p>
</a>
</div>
</div>
32 changes: 32 additions & 0 deletions examples/object-segementation/README.md
@@ -0,0 +1,32 @@
<!---
Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->

# Object Segmentation Examples

This directory contains an example script that demonstrates how to perform object segmentation on Gaudi using HPU graphs; a condensed sketch of the flow is shown after the validated-models list below.

## Single-HPU inference

```bash
python3 run_example.py \
--model_name_or_path "CIDAS/clipseg-rd64-refined" \
--image_path "http://images.cocodataset.org/val2017/000000039769.jpg" \
--prompt "cat, remote, blanket" \
--warmup 3 \
--n_iterations 20 \
--use_hpu_graphs \
--bf16 \
--print_result
```
Models that have been validated:
- [clipseg-rd64-refined](https://huggingface.co/CIDAS/clipseg-rd64-refined)
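
For orientation, the core of the example boils down to: patch Transformers for Gaudi, load the CLIPSeg processor and model, optionally wrap the model in an HPU graph, run inference under bf16 autocast, and threshold the sigmoid of the logits to get one binary mask per prompt. The sketch below condenses the full `run_example.py` shown further down in this diff; it is illustrative only and omits the warmup and benchmarking logic.

```python
# Condensed sketch of the flow implemented in run_example.py (warmup/benchmarking omitted).
import habana_frameworks.torch as ht
import requests
import torch
from PIL import Image
from transformers import AutoProcessor, CLIPSegForImageSegmentation

from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi

adapt_transformers_to_gaudi()  # patch Transformers for Gaudi before loading the model

processor = AutoProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")
model = ht.hpu.wrap_in_hpu_graph(model)  # corresponds to --use_hpu_graphs
model = model.to("hpu")

image = Image.open(
    requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw
)
texts = ["cat", "remote", "blanket"]

with torch.no_grad(), torch.autocast(device_type="hpu", dtype=torch.bfloat16):  # corresponds to --bf16
    inputs = processor(text=texts, images=[image] * len(texts), padding=True, return_tensors="pt").to("hpu")
    logits = model(**inputs).logits  # one low-resolution mask per prompt
    torch.hpu.synchronize()

masks = (torch.sigmoid(logits) > 0.5).to(torch.float32)  # binary mask per prompt, as in the script
```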
118 changes: 118 additions & 0 deletions examples/object-segementation/run_example.py
@@ -0,0 +1,118 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Copied from https://huggingface.co/docs/transformers/main/en/model_doc/clipseg

import argparse
import time

import habana_frameworks.torch as ht
import requests
import torch
from PIL import Image
from torchvision.utils import save_image
from transformers import AutoProcessor, CLIPSegForImageSegmentation

from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--model_name_or_path",
        default="CIDAS/clipseg-rd64-refined",
        type=str,
        help="Path of the pre-trained model",
    )
    parser.add_argument(
        "--image_path",
        default="http://images.cocodataset.org/val2017/000000039769.jpg",
        type=str,
        help='Path of the input image. Should be a single string (e.g. --image_path "URL")',
    )
    parser.add_argument(
        "--prompt",
        default="a cat,a remote,a blanket",
        type=str,
        help='Prompt for classification. It should be a comma-separated string (e.g. --prompt "a photo of a cat, a photo of a dog")',
    )
    parser.add_argument(
        "--use_hpu_graphs",
        action="store_true",
        help="Whether to use HPU graphs or not. Using HPU graphs should give better latencies.",
    )
    parser.add_argument(
        "--bf16",
        action="store_true",
        help="Whether to use bf16 precision for classification.",
    )
    parser.add_argument(
        "--print_result",
        action="store_true",
        help="Whether to print the classification results.",
    )
    parser.add_argument("--warmup", type=int, default=3, help="Number of warmup iterations for benchmarking.")
    parser.add_argument("--n_iterations", type=int, default=5, help="Number of inference iterations for benchmarking.")

    args = parser.parse_args()

    adapt_transformers_to_gaudi()

    processor = AutoProcessor.from_pretrained(args.model_name_or_path)
    # Use CLIPSegForImageSegmentation instead of AutoModel:
    # its output contains the logits needed to generate the segmented images
    model = CLIPSegForImageSegmentation.from_pretrained(args.model_name_or_path)

    image = Image.open(requests.get(args.image_path, stream=True).raw)
    texts = args.prompt.split(",")

    if args.use_hpu_graphs:
        model = ht.hpu.wrap_in_hpu_graph(model)

    autocast = torch.autocast(device_type="hpu", dtype=torch.bfloat16, enabled=args.bf16)
    model.to("hpu")

    with torch.no_grad(), autocast:
        for i in range(args.warmup):
            inputs = processor(text=texts, images=[image] * len(texts), padding=True, return_tensors="pt").to("hpu")
            outputs = model(**inputs)
            torch.hpu.synchronize()

        total_model_time = 0
        for i in range(args.n_iterations):
            inputs = processor(text=texts, images=[image] * len(texts), padding=True, return_tensors="pt").to("hpu")
            model_start_time = time.time()
            outputs = model(**inputs)
            torch.hpu.synchronize()
            model_end_time = time.time()
            total_model_time = total_model_time + (model_end_time - model_start_time)

            if args.print_result:
                if i == 0:  # generate/output once only
                    logits = outputs.logits
                    for j in range(logits.shape[0]):
                        threshold = 0.5
                        segmented_image = ((torch.sigmoid(logits[j]) > threshold) * 255).unsqueeze(0)
                        segmented_image = segmented_image.to(torch.float32)
                        save_image(segmented_image, "segmented_" + texts[j].strip() + ".png")
                    print("Segmented images are generated.")

        print("n_iterations: " + str(args.n_iterations))
        print("Total latency (ms): " + str(total_model_time * 1000))
        print("Average latency (ms): " + str(total_model_time * 1000 / args.n_iterations))
114 changes: 114 additions & 0 deletions tests/test_object_segmentation.py
@@ -0,0 +1,114 @@
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import time
from unittest import TestCase

import habana_frameworks.torch as ht
import numpy as np
import requests
import torch
from PIL import Image
from transformers import AutoModel, AutoProcessor

from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi


adapt_transformers_to_gaudi()

# For Gaudi 2
LATENCY_ClipSeg_BF16_GRAPH_BASELINE = 5.3107380867004395


class GaudiClipSegTester(TestCase):
    """
    Tests for ClipSeg model
    """

    def prepare_model_and_processor(self):
        model = AutoModel.from_pretrained("CIDAS/clipseg-rd64-refined").to("hpu")
        processor = AutoProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
        model = model.eval()
        return model, processor

    def prepare_data(self):
        url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        image = Image.open(requests.get(url, stream=True).raw)
        texts = ["a cat", "a remote", "a blanket"]
        return texts, image

    def test_inference_default(self):
        model, processor = self.prepare_model_and_processor()
        texts, image = self.prepare_data()
        inputs = processor(text=texts, images=[image] * len(texts), padding=True, return_tensors="pt").to("hpu")
        outputs = model(**inputs)
        probs = outputs.logits_per_image.softmax(dim=-1).detach().cpu().numpy()[0]
        expected_scores = np.array([0.02889409, 0.87959206, 0.09151383])  # from CPU
        self.assertEqual(len(probs), 3)
        self.assertLess(np.abs(probs - expected_scores).max(), 0.01)

    def test_inference_autocast(self):
        model, processor = self.prepare_model_and_processor()
        texts, image = self.prepare_data()
        inputs = processor(text=texts, images=[image] * len(texts), padding=True, return_tensors="pt").to("hpu")

        with torch.autocast(device_type="hpu", dtype=torch.bfloat16):  # Autocast BF16
            outputs = model(**inputs)
            probs = outputs.logits_per_image.softmax(dim=-1).to(torch.float32).detach().cpu().numpy()[0]
            expected_scores = np.array([0.02889409, 0.87959206, 0.09151383])  # from CPU
            self.assertEqual(len(probs), 3)
            self.assertEqual(probs.argmax(), expected_scores.argmax())

    def test_inference_hpu_graphs(self):
        model, processor = self.prepare_model_and_processor()
        texts, image = self.prepare_data()
        inputs = processor(text=texts, images=[image] * len(texts), padding=True, return_tensors="pt").to("hpu")

        model = ht.hpu.wrap_in_hpu_graph(model)  # Apply graph

        outputs = model(**inputs)
        probs = outputs.logits_per_image.softmax(dim=-1).to(torch.float32).detach().cpu().numpy()[0]
        expected_scores = np.array([0.02889409, 0.87959206, 0.09151383])  # from CPU
        self.assertEqual(len(probs), 3)
        self.assertEqual(probs.argmax(), expected_scores.argmax())

    def test_no_latency_regression_autocast(self):
        warmup = 3
        iterations = 20

        model, processor = self.prepare_model_and_processor()
        texts, image = self.prepare_data()

        model = ht.hpu.wrap_in_hpu_graph(model)

        with torch.no_grad(), torch.autocast(device_type="hpu", dtype=torch.bfloat16, enabled=True):
            for i in range(warmup):
                inputs = processor(text=texts, images=[image] * len(texts), padding=True, return_tensors="pt").to(
                    "hpu"
                )
                _ = model(**inputs)
                torch.hpu.synchronize()

            total_model_time = 0
            for i in range(iterations):
                inputs = processor(text=texts, images=[image] * len(texts), padding=True, return_tensors="pt").to(
                    "hpu"
                )
                model_start_time = time.time()
                _ = model(**inputs)
                torch.hpu.synchronize()
                model_end_time = time.time()
                total_model_time = total_model_time + (model_end_time - model_start_time)

        latency = total_model_time * 1000 / iterations  # in terms of ms
        self.assertLessEqual(latency, 1.05 * LATENCY_ClipSeg_BF16_GRAPH_BASELINE)