Skip to content

Commit

Permalink
Aligned the file architecture by moving the files under object-segmen…
Browse files Browse the repository at this point in the history
…tation.

Used Automodel and related processor to replace model-specific API.
Improved the testing logic.
  • Loading branch information
cfgfung committed May 29, 2024
1 parent f436de1 commit b1a5102
Show file tree
Hide file tree
Showing 5 changed files with 44 additions and 185 deletions.
23 changes: 21 additions & 2 deletions examples/object-segementation/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,12 @@ limitations under the License.

# Object Segmentation Examples

This directory contains an example script that demonstrates how to perform object segmentation on Gaudi with graph mode.
This directory contains two example script that demonstrates how to perform object segmentation on Gaudi with graph mode.

## Single-HPU inference

### ClipSeg Model

```bash
python3 run_example.py \
--model_name_or_path "CIDAS/clipseg-rd64-refined" \
Expand All @@ -29,4 +31,21 @@ python3 run_example.py \
--print_result
```
Models that have been validated:
- [clipseg-rd64-refined ](https://huggingface.co/CIDAS/clipseg-rd64-refined)
- [clipseg-rd64-refined ](https://huggingface.co/CIDAS/clipseg-rd64-refined)

### Segment Anything Model

```bash
python3 run_example_sam.py \
--model_name_or_path "facebook/sam-vit-huge" \
--image_path "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png" \
--point_prompt "450,600" \
--warmup 3 \
--n_iterations 20 \
--use_hpu_graphs \
--bf16 \
--print_result
```
Models that have been validated:
- [facebook/sam-vit-base](https://huggingface.co/facebook/sam-vit-base)
- [facebook/sam-vit-huge](https://huggingface.co/facebook/sam-vit-huge)
33 changes: 0 additions & 33 deletions examples/object-segementation/SegmentAnythingModel/README.md

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,18 @@

# Copied from https://huggingface.co/facebook/sam-vit-base

from transformers import SamModel, SamProcessor
from PIL import Image
import argparse
import time

import habana_frameworks.torch as ht
import requests
import torch
import habana_frameworks.torch as ht
import habana_frameworks.torch.core as htcore
import time
import argparse
from PIL import Image
from transformers import AutoModel, AutoProcessor

from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi


if __name__ == "__main__":
parser = argparse.ArgumentParser()

Expand Down Expand Up @@ -69,8 +70,8 @@

adapt_transformers_to_gaudi()

processor = SamProcessor.from_pretrained(args.model_name_or_path)
model = SamModel.from_pretrained(args.model_name_or_path)
processor = AutoProcessor.from_pretrained(args.model_name_or_path)
model = AutoModel.from_pretrained(args.model_name_or_path)

image = Image.open(requests.get(args.image_path, stream=True).raw).convert("RGB")
points = []
Expand Down
35 changes: 14 additions & 21 deletions tests/test_image_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import requests
from PIL import Image
import torch
import habana_frameworks.torch as ht
import habana_frameworks.torch.core as htcore
import time
import argparse
from transformers import OwlViTProcessor, OwlViTForObjectDetection, SamProcessor, SamModel
import unittest
from unittest import TestCase

import habana_frameworks.torch as ht
import numpy as np
import os
import requests
import torch
from PIL import Image
from transformers import AutoModel, AutoProcessor

from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi


adapt_transformers_to_gaudi()

# For Gaudi 2
Expand All @@ -39,8 +37,8 @@ class GaudiSAMTester(TestCase):
Tests for Segment Anything Model - SAM
"""
def prepare_model_and_processor(self):
model = SamModel.from_pretrained("facebook/sam-vit-huge").to("hpu")
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
model = AutoModel.from_pretrained("facebook/sam-vit-huge").to("hpu")
processor = AutoProcessor.from_pretrained("facebook/sam-vit-huge")
model = model.eval()
return model, processor

Expand All @@ -54,7 +52,6 @@ def test_inference_default(self):
input_points, image = self.prepare_data()
inputs = processor(image, input_points=input_points, return_tensors="pt").to("hpu")
outputs = model(**inputs)
masks = processor.image_processor.post_process_masks(outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu())
scores = outputs.iou_scores
scores = scores[0][0]
expected_scores = np.array([0.9912, 0.9818, 0.9666])
Expand All @@ -68,7 +65,6 @@ def test_inference_bf16(self):

with torch.autocast(device_type="hpu", dtype=torch.bfloat16): # Autocast BF16
outputs = model(**inputs)
masks = processor.image_processor.post_process_masks(outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu())
scores = outputs.iou_scores
scores = scores[0][0]
expected_scores = np.array([0.9912, 0.9818, 0.9666])
Expand All @@ -83,7 +79,6 @@ def test_inference_hpu_graphs(self):
model = ht.hpu.wrap_in_hpu_graph(model) #Apply graph

outputs = model(**inputs)
masks = processor.image_processor.post_process_masks(outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu())
scores = outputs.iou_scores
scores = scores[0][0]
expected_scores = np.array([0.9912, 0.9818, 0.9666])
Expand All @@ -102,20 +97,18 @@ def test_no_latency_regression_bf16(self):
with torch.no_grad(), torch.autocast(device_type="hpu", dtype=torch.bfloat16, enabled=True):
for i in range(warmup):
inputs = processor(image, input_points=input_points, return_tensors="pt").to("hpu")
outputs = model(**inputs)
_ = model(**inputs)
torch.hpu.synchronize()

total_model_time = 0
for i in range(iterations):
inputs = processor(image, input_points=input_points, return_tensors="pt").to("hpu")
model_start_time = time.time()
outputs = model(**inputs)
_ = model(**inputs)
torch.hpu.synchronize()
model_end_time = time.time()
total_model_time = total_model_time + (model_end_time - model_start_time)

latency = total_model_time*1000/iterations # in terms of ms
self.assertGreaterEqual(latency, 0.95 * LATENCY_SAM_BF16_GRAPH_BASELINE)
self.assertLessEqual(latency, 1.05 * LATENCY_SAM_BF16_GRAPH_BASELINE)

# if __name__ == '__main__':
# unittest.main()
121 changes: 0 additions & 121 deletions tests/test_modelenabling.py

This file was deleted.

0 comments on commit b1a5102

Please sign in to comment.