
Logits of inference #302

Open
Corallo opened this issue Mar 6, 2024 · 3 comments

Comments

@Corallo
Copy link

Corallo commented Mar 6, 2024

Hi, if I understand correctly, the logits are the features of the detected objects.

I was wondering why, in the inference script, you extract only the first element (making the rest unusable) instead of keeping them all.
https://github.com/IDEA-Research/GroundingDINO/blob/main/groundingdino/util/inference.py#L97

@heyoeyo
Copy link

heyoeyo commented Mar 6, 2024

The max function returns both the maximum values and the corresponding indices of those max values, at least when used with dim=1. So (I think) the [0] indexing is used to get only the max values themselves (and not their indices), which are likely meant to be used as a confidence 'score' for each of the boxes.

@Corallo
Copy link
Author

Corallo commented Mar 8, 2024

@heyoeyo I'm not sure I follow. What is outputs["pred_logits"] supposed to be in the first place?

@heyoeyo
Copy link

heyoeyo commented Mar 8, 2024

As far as I understand, the logits are meant to be a numerical representation of info about each of the bounding boxes predicted by the model. They can be thought of as just being an array of (256) numbers, one array for each bounding box.

They seem to use the logits in 3 (similar) ways in that function:

  1. They're used to keep only the 'good' bounding box predictions. A box is considered 'good' if the largest of the (256) logit values is above some box_threshold value.
  2. They're used to figure out which part of the text prompt goes with each box. This happens in the get_phrases_from_posmap(...) function calls. It works in a similar way, where logit values above some text_threshold indicate which part of the input text should be assigned to the box.
  3. They're used as an overall 'score' of the quality of the box + text label (the part that you linked)
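The first and third uses above can be sketched roughly like this (the shapes, the threshold value, and the variable names here are illustrative guesses, not the actual GroundingDINO code):

```python
import torch

# Hypothetical stand-in for outputs["pred_logits"] after a sigmoid:
# one row of 256 values per predicted box.
torch.manual_seed(0)
logits = torch.rand(900, 256)

box_threshold = 0.35  # illustrative value

# (1) keep only the 'good' boxes: the largest logit per row must
#     exceed box_threshold
max_per_box = logits.max(dim=1)[0]
keep = max_per_box > box_threshold
good_logits = logits[keep]

# (3) the per-box 'score' is just that same max value, which is what
#     the [0] indexing in the linked line extracts
scores = good_logits.max(dim=1)[0]

print(keep.sum().item(), scores.shape)
```

So the `[0]` isn't throwing away useful logit data per se; it's selecting the "values" half of max's (values, indices) output.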

Anyways, the easiest way to understand the weird indexing on the output is to try it with some sample data. You can try running something like:

import torch
logits = torch.randint(0,10,(6,2))

print(logits)
# example output (values are random each run):
# tensor([[5, 7],
#         [0, 9],
#         [2, 0],
#         [9, 7],
#         [6, 6],
#         [9, 0]])

print(logits.max(dim=1)[0])
# tensor([7, 9, 2, 9, 6, 9])

print(logits.max(dim=1)[1])
# tensor([1, 1, 0, 0, 0, 0])

You can see the .max(dim=1)[0] just gives the maximum number along each row. If you do .max(dim=1)[1] you get the column position of each of those max values.
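As a side note, torch.max with a dim argument returns a namedtuple, so you can also access the same two results by name, which reads a bit more clearly than the [0]/[1] indexing:

```python
import torch

logits = torch.tensor([[5, 7],
                       [0, 9],
                       [2, 0]])

# torch.max with dim= returns a (values, indices) namedtuple
result = logits.max(dim=1)

print(result.values)   # same as logits.max(dim=1)[0]
print(result.indices)  # same as logits.max(dim=1)[1]
```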
