diff --git a/Makefile b/Makefile
index 6e87a399a..bec72a57b 100644
--- a/Makefile
+++ b/Makefile
@@ -41,6 +41,11 @@ fast_tests_diffusers:
python -m pip install .[tests]
python -m pytest tests/test_diffusers.py
+# Run unit and integration tests related to image segmentation
+fast_tests_image_segmentation:
+ python -m pip install .[tests]
+ python -m pytest tests/test_image_segmentation.py
+
# Run single-card non-regression tests
slow_tests_1x: test_installs
python -m pytest tests/test_examples.py -v -s -k "single_card"
diff --git a/README.md b/README.md
index fabff9e26..9c88e1481 100644
--- a/README.md
+++ b/README.md
@@ -214,7 +214,7 @@ The following model architectures, tasks and device distributions have been vali
| OWLViT | | Single card | [zero shot object detection](https://github.com/huggingface/optimum-habana/tree/main/examples/zero-shot-object-detection) |
| ClipSeg | | Single card | [object segmentation](https://github.com/huggingface/optimum-habana/tree/main/examples/object-segementation) |
| Llava / Llava-next | | Single card | [image to text](https://github.com/huggingface/optimum-habana/tree/main/examples/image-to-text) |
-
+| Segment Anything Model | | Single card | [object segmentation](https://github.com/huggingface/optimum-habana/tree/main/examples/object-segementation) |
- Diffusers:
diff --git a/docs/source/index.mdx b/docs/source/index.mdx
index b33cfd062..22b8dcba5 100644
--- a/docs/source/index.mdx
+++ b/docs/source/index.mdx
@@ -72,7 +72,7 @@ In the tables below, ✅ means single-card, multi-card and DeepSpeed have all be
| OWLViT | | Single card | [zero shot object detection](https://github.com/huggingface/optimum-habana/tree/main/examples/zero-shot-object-detection) |
| ClipSeg | | Single card | [object segmentation](https://github.com/huggingface/optimum-habana/tree/main/examples/object-segementation) |
| Llava / Llava-next | | Single card | [image to text](https://github.com/huggingface/optimum-habana/tree/main/examples/image-to-text) |
-
+| SAM | | Single card | [object segmentation](https://github.com/huggingface/optimum-habana/tree/main/examples/object-segementation) |
- Diffusers
diff --git a/examples/object-segementation/README.md b/examples/object-segementation/README.md
index 4afb59849..fa1496a54 100644
--- a/examples/object-segementation/README.md
+++ b/examples/object-segementation/README.md
@@ -13,10 +13,12 @@ limitations under the License.
# Object Segmentation Examples
-This directory contains an example script that demonstrates how to perform object segmentation on Gaudi with graph mode.
+This directory contains two example scripts that demonstrate how to perform object segmentation on Gaudi with graph mode.
## Single-HPU inference
+### ClipSeg Model
+
```bash
python3 run_example.py \
--model_name_or_path "CIDAS/clipseg-rd64-refined" \
@@ -29,4 +31,21 @@ python3 run_example.py \
--print_result
```
Models that have been validated:
- - [clipseg-rd64-refined ](https://huggingface.co/CIDAS/clipseg-rd64-refined)
\ No newline at end of file
+ - [clipseg-rd64-refined](https://huggingface.co/CIDAS/clipseg-rd64-refined)
+
+### Segment Anything Model
+
+```bash
+python3 run_example_sam.py \
+ --model_name_or_path "facebook/sam-vit-huge" \
+ --image_path "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png" \
+ --point_prompt "450,600" \
+ --warmup 3 \
+ --n_iterations 20 \
+ --use_hpu_graphs \
+ --bf16 \
+ --print_result
+```
+Models that have been validated:
+ - [facebook/sam-vit-base](https://huggingface.co/facebook/sam-vit-base)
+ - [facebook/sam-vit-huge](https://huggingface.co/facebook/sam-vit-huge)
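+
+The model returns low-resolution masks; as a post-processing step they can be upscaled back to the original image size. Below is a minimal sketch, reusing the `processor`, `inputs` and `outputs` names from `run_example_sam.py` and assuming the standard `post_process_masks` helper from Transformers:
+
+```python
+# Upscale the predicted masks to the original image resolution (sketch)
+masks = processor.image_processor.post_process_masks(
+    outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu()
+)
+```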
\ No newline at end of file
diff --git a/examples/object-segementation/run_example_sam.py b/examples/object-segementation/run_example_sam.py
new file mode 100644
index 000000000..016b318be
--- /dev/null
+++ b/examples/object-segementation/run_example_sam.py
@@ -0,0 +1,110 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Copied from https://huggingface.co/facebook/sam-vit-base
+
+import argparse
+import time
+
+import habana_frameworks.torch as ht
+import requests
+import torch
+from PIL import Image
+from transformers import AutoModel, AutoProcessor
+
+from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument(
+ "--model_name_or_path",
+ default="facebook/sam-vit-huge",
+ type=str,
+ help="Path of the pre-trained model",
+ )
+ parser.add_argument(
+ "--image_path",
+ default="https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png",
+ type=str,
+        help='Path or URL of the input image, as a single string (e.g. --image_path "URL")',
+ )
+ parser.add_argument(
+ "--point_prompt",
+ default="450, 600",
+ type=str,
+        help='Point prompt for segmentation, given as comma-separated pixel coordinates (e.g. --point_prompt "450,600")',
+ )
+ parser.add_argument(
+ "--use_hpu_graphs",
+ action="store_true",
+ help="Whether to use HPU graphs or not. Using HPU graphs should give better latencies.",
+ )
+ parser.add_argument(
+ "--bf16",
+ action="store_true",
+        help="Whether to use bf16 precision for segmentation.",
+ )
+ parser.add_argument(
+ "--print_result",
+ action="store_true",
+        help="Whether to print the segmentation result (IoU scores).",
+ )
+ parser.add_argument("--warmup", type=int, default=3, help="Number of warmup iterations for benchmarking.")
+ parser.add_argument("--n_iterations", type=int, default=5, help="Number of inference iterations for benchmarking.")
+
+ args = parser.parse_args()
+
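+    # Patch Transformers with Gaudi-optimized model implementations before loading the model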
+ adapt_transformers_to_gaudi()
+
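+    # AutoProcessor/AutoModel resolve to SamProcessor/SamModel for SAM checkpoints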
+ processor = AutoProcessor.from_pretrained(args.model_name_or_path)
+ model = AutoModel.from_pretrained(args.model_name_or_path)
+
+ image = Image.open(requests.get(args.image_path, stream=True).raw).convert("RGB")
+    # Parse the point prompt into the nested list format SAM expects: [image batch, point set, (x, y)]
+    points = [[[int(coord) for coord in args.point_prompt.split(",")]]]
+
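+    # HPU graphs record the forward pass once and replay it, reducing host-side launch overhead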
+ if args.use_hpu_graphs:
+ model = ht.hpu.wrap_in_hpu_graph(model)
+
+ autocast = torch.autocast(device_type="hpu", dtype=torch.bfloat16, enabled=args.bf16)
+ model.to("hpu")
+
+ with torch.no_grad(), autocast:
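+        # Warmup iterations trigger graph compilation/caching so the timed loop below measures steady-state latency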
+ for i in range(args.warmup):
+ inputs = processor(image, input_points=points, return_tensors="pt").to("hpu")
+ outputs = model(**inputs)
+ torch.hpu.synchronize()
+
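+        # Time only the forward pass; preprocessing runs before the timer starts each iteration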
+ total_model_time = 0
+ for i in range(args.n_iterations):
+ inputs = processor(image, input_points=points, return_tensors="pt").to("hpu")
+ model_start_time = time.time()
+ outputs = model(**inputs)
+ torch.hpu.synchronize()
+ model_end_time = time.time()
+ total_model_time = total_model_time + (model_end_time - model_start_time)
+
+ if args.print_result:
+            if i == 0:  # print the IoU scores once only
+ iou = outputs.iou_scores
+ print("iou score: " + str(iou))
+
+ print("n_iterations: " + str(args.n_iterations))
+    print("Total latency (ms): " + str(total_model_time * 1000))
+    print("Average latency (ms): " + str(total_model_time * 1000 / args.n_iterations))
\ No newline at end of file
diff --git a/tests/test_image_segmentation.py b/tests/test_image_segmentation.py
new file mode 100644
index 000000000..cae5042af
--- /dev/null
+++ b/tests/test_image_segmentation.py
@@ -0,0 +1,114 @@
+# coding=utf-8
+# Copyright 2024 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import time
+from unittest import TestCase
+
+import habana_frameworks.torch as ht
+import numpy as np
+import requests
+import torch
+from PIL import Image
+from transformers import AutoModel, AutoProcessor
+
+from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
+
+
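+# Patch Transformers with Gaudi-optimized implementations once, at module import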
+adapt_transformers_to_gaudi()
+
+# For Gaudi 2
+LATENCY_SAM_BF16_GRAPH_BASELINE = 98.92215728759766
+
+class GaudiSAMTester(TestCase):
+ """
+ Tests for Segment Anything Model - SAM
+ """
+ def prepare_model_and_processor(self):
+ model = AutoModel.from_pretrained("facebook/sam-vit-huge").to("hpu")
+ processor = AutoProcessor.from_pretrained("facebook/sam-vit-huge")
+ model = model.eval()
+ return model, processor
+
+ def prepare_data(self):
+ image = Image.open(requests.get("https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png", stream=True).raw).convert("RGB")
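+        # A single (x, y) point prompt, nested as [image batch, point set, point] as SAM expects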
+ input_points = [[[450, 600]]]
+ return input_points, image
+
+ def test_inference_default(self):
+ model, processor = self.prepare_model_and_processor()
+ input_points, image = self.prepare_data()
+ inputs = processor(image, input_points=input_points, return_tensors="pt").to("hpu")
+ outputs = model(**inputs)
+ scores = outputs.iou_scores
+ scores = scores[0][0]
+ expected_scores = np.array([0.9912, 0.9818, 0.9666])
+ self.assertEqual(len(scores), 3)
+ self.assertLess(np.abs(scores.cpu().detach().numpy() - expected_scores).max(), 0.02)
+
+ def test_inference_bf16(self):
+ model, processor = self.prepare_model_and_processor()
+ input_points, image = self.prepare_data()
+ inputs = processor(image, input_points=input_points, return_tensors="pt").to("hpu")
+
+        with torch.autocast(device_type="hpu", dtype=torch.bfloat16):  # autocast to bf16
+ outputs = model(**inputs)
+ scores = outputs.iou_scores
+ scores = scores[0][0]
+ expected_scores = np.array([0.9912, 0.9818, 0.9666])
+ self.assertEqual(len(scores), 3)
+ self.assertLess(np.abs(scores.to(torch.float32).cpu().detach().numpy() - expected_scores).max(), 0.02)
+
+ def test_inference_hpu_graphs(self):
+ model, processor = self.prepare_model_and_processor()
+ input_points, image = self.prepare_data()
+ inputs = processor(image, input_points=input_points, return_tensors="pt").to("hpu")
+
+        model = ht.hpu.wrap_in_hpu_graph(model)  # Wrap the model in an HPU graph
+
+ outputs = model(**inputs)
+ scores = outputs.iou_scores
+ scores = scores[0][0]
+ expected_scores = np.array([0.9912, 0.9818, 0.9666])
+ self.assertEqual(len(scores), 3)
+ self.assertLess(np.abs(scores.to(torch.float32).cpu().detach().numpy() - expected_scores).max(), 0.02)
+
+ def test_no_latency_regression_bf16(self):
+ warmup = 3
+ iterations = 10
+
+ model, processor = self.prepare_model_and_processor()
+ input_points, image = self.prepare_data()
+
+ model = ht.hpu.wrap_in_hpu_graph(model)
+
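+        # Mirror the benchmark setup of run_example_sam.py: HPU graphs with bf16 autocast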
+ with torch.no_grad(), torch.autocast(device_type="hpu", dtype=torch.bfloat16, enabled=True):
+ for i in range(warmup):
+ inputs = processor(image, input_points=input_points, return_tensors="pt").to("hpu")
+ _ = model(**inputs)
+ torch.hpu.synchronize()
+
+ total_model_time = 0
+ for i in range(iterations):
+ inputs = processor(image, input_points=input_points, return_tensors="pt").to("hpu")
+ model_start_time = time.time()
+ _ = model(**inputs)
+ torch.hpu.synchronize()
+ model_end_time = time.time()
+ total_model_time = total_model_time + (model_end_time - model_start_time)
+
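+        # Allow up to 5% slowdown over the recorded Gaudi 2 baseline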
+        latency = total_model_time * 1000 / iterations  # average latency in ms
+ self.assertLessEqual(latency, 1.05 * LATENCY_SAM_BF16_GRAPH_BASELINE)
+