Provided pooled_prompt_embeds is overwritten via prompt_embeds[0] #7365

Open
cloneofsimo opened this issue Mar 18, 2024 · 10 comments · May be fixed by #7926

Comments

@cloneofsimo
Contributor

Simple fix:

pooled_prompt_embeds = prompt_embeds[0] if pooled_prompt_embeds is None else pooled_prompt_embeds

Sorry this isn't a pr :P
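For context, a minimal sketch of where such a guard would sit inside the pipeline's encode_prompt. The surrounding code is sketched from the SDXL pipeline rather than copied from the source; text_encoder here stands for the second SDXL text encoder (a CLIPTextModelWithProjection), whose output's first element is the projected pooled embedding:

    if prompt_embeds is None:
        encoder_output = text_encoder(text_input_ids, output_hidden_states=True)
        # Only derive the pooled embedding when the caller did not supply one,
        # so a user-provided pooled_prompt_embeds is no longer overwritten.
        if pooled_prompt_embeds is None:
            pooled_prompt_embeds = encoder_output[0]
        # The SDXL pipelines use the penultimate hidden state as the prompt embedding.
        prompt_embeds = encoder_output.hidden_states[-2]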

@sayakpaul
Member

Golden catch. Please PR it.

@bghira

This comment was marked as outdated.

@bghira

This comment was marked as outdated.

@AmericanPresidentJimmyCarter
Contributor

I wasn't sure how the output was being pooled, so we went through the code trying to figure it out. It seems like

pooled_prompt_embeds = prompt_embeds[0]

Should really be

            pooled_prompt_embeds = list(filter(lambda x: x is not None, [
                getattr(prompt_embeds, 'pooler_output', None),
                getattr(prompt_embeds, 'text_embeds', None),
            ]))[0]

For clarity, since the two CLIP text encoders return completely different output classes and keep their pooled outputs under different attributes (pooler_output vs. text_embeds). My original concern was that, for one of the CLIP models, indexing with [0] might actually be selecting the last hidden state (and effectively the first token) rather than the pooled output. Simply referencing it as [0] is extremely unclear given the nature of the output from the two CLIP models.
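To make the difference concrete, here is a small runnable sketch (randomly initialized encoders purely for illustration, not the SDXL checkpoints) showing what indexing with [0] resolves to for each encoder class:

    import torch
    from transformers import CLIPTextConfig, CLIPTextModel, CLIPTextModelWithProjection

    config = CLIPTextConfig()  # default config, random weights; illustration only
    input_ids = torch.randint(0, config.vocab_size, (1, 8))

    out_1 = CLIPTextModel(config)(input_ids)
    # BaseModelOutputWithPooling: the first field is last_hidden_state, so
    # out_1[0] is the per-token hidden states, not a pooled vector.
    print(out_1[0].shape)                       # (1, 8, hidden_size)
    print(out_1[0] is out_1.last_hidden_state)  # True
    print(out_1.pooler_output.shape)            # (1, hidden_size)

    out_2 = CLIPTextModelWithProjection(config)(input_ids)
    # CLIPTextModelOutput: the first field is text_embeds (the projected pooled
    # output), so out_2[0] really is the pooled embedding here.
    print(out_2[0].shape)                       # (1, projection_dim)
    print(out_2[0] is out_2.text_embeds)        # True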

@sayakpaul
Member

The order in which the tokenizers and text encoders are passed matters, so I think the implementation is correct. If a clarifying comment would help, please file a PR; more than happy to prioritize it.

Simply referencing it as [0] is extremely unclear given the nature of the output from the two CLIP models.

The reason it's there is that it helps with torch.compile(); otherwise there's a TensorVariable access problem.
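For reference, a minimal sketch of the positional-indexing pattern being discussed (randomly initialized encoder for illustration; that this avoids the dynamo graph break is taken from the comment above rather than verified here):

    import torch
    from transformers import CLIPTextConfig, CLIPTextModelWithProjection

    encoder = CLIPTextModelWithProjection(CLIPTextConfig())  # random weights, illustration only
    input_ids = torch.randint(0, encoder.config.vocab_size, (1, 8))

    # Positional indexing, as in the pipeline: works whether the encoder returns
    # a ModelOutput or a plain tuple (return_dict=False).
    pooled = encoder(input_ids, return_dict=False)[0]

    # Attribute access only exists on the ModelOutput wrapper.
    pooled_attr = encoder(input_ids).text_embeds
    assert torch.allclose(pooled, pooled_attr)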


This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions github-actions bot added the stale Issues that haven't received updates label Apr 17, 2024
@bghira
Contributor

bghira commented Apr 17, 2024

not stale

@github-actions github-actions bot removed the stale Issues that haven't received updates label Apr 18, 2024

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions github-actions bot added the stale Issues that haven't received updates label May 12, 2024
@bghira
Contributor

bghira commented May 12, 2024

still not stale :D

bghira pushed a commit to bghira/diffusers that referenced this issue May 12, 2024
@bghira
Contributor

bghira commented May 12, 2024

@sayakpaul I opened the pull request for this, but the code in question only runs when prompt_embeds is None.

Do we want to mix and match a user-provided pooled_prompt_embeds with generated prompt embeds?

@github-actions github-actions bot removed the stale Issues that haven't received updates label May 13, 2024
bghira pushed a commit to bghira/diffusers that referenced this issue May 16, 2024