inquiry of code/result difference between SAM and GSAM #494

Open

TikaToka (Contributor) opened this issue Apr 29, 2024 · 0 comments

Hello, thank you for sharing this amazing work!

I am trying to adapt the GSAM code as a base model, but I have a question.

import torch
from transformers import SamModel, SamProcessor

device = "cuda" if torch.cuda.is_available() else "cpu"
sam_model = SamModel.from_pretrained("facebook/sam-vit-huge").to(device)
sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")

# `image`, `preds`, and `sampled_points` come from my own pipeline
sam_masks = []
for idx in range(preds.shape[0]):
    sam_inputs = sam_processor(image, input_points=[sampled_points[idx]], return_tensors="pt").to(device)

    with torch.no_grad():
        sam_outputs = sam_model(**sam_inputs)
        print(sam_outputs)
        print(sam_outputs.pred_masks.cpu().shape)

    # upscale the low-resolution masks back to the original image size
    sam_masks.append(sam_processor.image_processor.post_process_masks(
        sam_outputs.pred_masks.cpu(), sam_inputs["original_sizes"].cpu(), sam_inputs["reshaped_input_sizes"].cpu()
    ))

With this code from SAM, each sam_mask has shape (1, 1, 3, h, w), so the stacked result is (n, 1, 1, 3, h, w).
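If I read the transformers docs correctly, the 3 on that axis is SAM's three multimask proposals: SamModel's forward call takes a multimask_output flag that defaults to True. As a sketch (same variables as above), forcing a single mask should collapse that axis to 1:

with torch.no_grad():
    # request a single mask instead of the default three proposals
    sam_outputs = sam_model(**sam_inputs, multimask_output=False)
# the mask axis should now be 1 instead of 3
print(sam_outputs.pred_masks.cpu().shape)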
However, if we use this code from GSAM:

import cv2
import torch
import torchvision

# `rgb_path`, `config_file`, `grounded_checkpoint`, the thresholds, `split`,
# and `device` are all set earlier in my script
image_pil, im = load_image(rgb_path)
# load model
model = load_model(config_file, grounded_checkpoint, device=device)

caption = generate_caption(rgb_path, device=device)
# Currently ", " is better for detecting single tags
# while ". " is a little worse in some case
text_prompt = generate_tags(caption, split=split)

boxes_filt, scores, pred_phrases = get_grounding_output(
    model, im, text_prompt, box_threshold, text_threshold, device=device
)

print(boxes_filt, scores, pred_phrases)

# initialize SAM
# if use_sam_hq:
#     print("Initialize SAM-HQ Predictor")
#     predictor = SamPredictor(build_sam_hq(checkpoint=sam_hq_checkpoint).to(device))
# else:
predictor = SamPredictor(build_sam(checkpoint=sam_checkpoint).to(device))
im = cv2.imread(rgb_path)
im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
predictor.set_image(im)

size = image_pil.size
H, W = size[1], size[0]
# convert boxes from normalized (cx, cy, w, h) to absolute (x1, y1, x2, y2)
for i in range(boxes_filt.size(0)):
    boxes_filt[i] = boxes_filt[i] * torch.Tensor([W, H, W, H])
    boxes_filt[i][:2] -= boxes_filt[i][2:] / 2
    boxes_filt[i][2:] += boxes_filt[i][:2]

boxes_filt = boxes_filt.cpu()
# use NMS to handle overlapped boxes
print(f"Before NMS: {boxes_filt.shape[0]} boxes")
nms_idx = torchvision.ops.nms(boxes_filt, scores, iou_threshold).numpy().tolist()
boxes_filt = boxes_filt[nms_idx]
pred_phrases = [pred_phrases[idx] for idx in nms_idx]
print(f"After NMS: {boxes_filt.shape[0]} boxes")
caption = check_caption(caption, pred_phrases)
print(f"Revise caption with number: {caption}")

# map the boxes into the coordinate frame of SAM's resized input
transformed_boxes = predictor.transform.apply_boxes_torch(boxes_filt, im.shape[:2]).to(device)

masks, _, _ = predictor.predict_torch(
    point_coords=None,
    point_labels=None,
    boxes=transformed_boxes,
    multimask_output=False,
)

Each mask in masks has shape (1, h, w), so the full output is (n, 1, h, w).

I just wonder why there is this dimensional gap between SAM and GSAM, and whether there is a way to get (1, 1, 3, h, w) per mask here as well.
It looks to me like the difference comes from multimask_output=True; if that is right, the conversion might be:

new = [torch.tensor([[mask]]) for mask in masks.cpu().tolist()]

but I want to make sure of it.
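In tensor terms, this is just my own sketch, assuming predict_torch with multimask_output=True returns masks of shape (n, 3, h, w) here:

masks, _, _ = predictor.predict_torch(
    point_coords=None,
    point_labels=None,
    boxes=transformed_boxes,
    multimask_output=True,  # three mask proposals per box: (n, 3, h, w)
)
# insert two singleton dims to mimic the transformers layout: (n, 1, 1, 3, h, w)
sam_like_masks = masks.cpu()[:, None, None]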

Thank you in advance, and have a nice day!
