
Cross Attention maps #8

Open · danielajisafe opened this issue May 19, 2024 · 1 comment

danielajisafe commented May 19, 2024

Hello,

Thank you so much for your great work and codebase!

I would appreciate your clarifications on a few items.

  1. From within TextToVideoSDPipelineCall.py, at this line, the attention maps from the temporal layers appear to be empty. I checked this with roughly the following code block:
for name, module in self.unet.named_modules():
    module_name = type(module).__name__
    if module_name == "Attention" and "attn2" in name:
        # --- First set: temporal cross-attention layers
        if "temp_attentions" in name:
            print(name)  # printed names use .0; listed below as [0]
            extracted_attention_map = module.processor.cross_attention_map
            if extracted_attention_map is not None:
                print(extracted_attention_map.shape)
        else:
            # --- Second set: spatial cross-attention layers and transformer_in
            ...
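For context, this is how I imagine the cross_attention_map attribute gets populated: a processor that caches the softmaxed attention weights on each forward pass. The sketch below is my own (the class name StoreCrossAttnProcessor is hypothetical, and it simplifies diffusers' default AttnProcessor by skipping masks, norms and residuals), so it may differ from what the repo actually does:

import torch

class StoreCrossAttnProcessor:
    """Minimal sketch: record the cross-attention map of an Attention module."""

    def __init__(self):
        self.cross_attention_map = None  # overwritten on every forward pass

    def __call__(self, attn, hidden_states, encoder_hidden_states=None,
                 attention_mask=None, **kwargs):
        # attn2 is cross-attention: keys/values come from the text embeddings
        # when encoder_hidden_states is provided; if it is None, this collapses
        # to self-attention (possibly what happens in the temporal layers).
        query = attn.to_q(hidden_states)
        context = hidden_states if encoder_hidden_states is None else encoder_hidden_states
        key = attn.to_k(context)
        value = attn.to_v(context)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        # (batch * heads, query_len, key_len) softmaxed attention weights;
        # the shapes printed below look further reshaped to (..., height, width, tokens)
        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        self.cross_attention_map = attention_probs.detach()

        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)
        hidden_states = attn.to_out[0](hidden_states)  # output projection
        hidden_states = attn.to_out[1](hidden_states)  # dropout
        return hidden_states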

- First set

down_blocks[0].temp_attentions[0].transformer_blocks[0].attn2
down_blocks[0].temp_attentions[1].transformer_blocks[0].attn2
down_blocks[1].temp_attentions[0].transformer_blocks[0].attn2
down_blocks[1].temp_attentions[1].transformer_blocks[0].attn2
down_blocks[2].temp_attentions[0].transformer_blocks[0].attn2
down_blocks[2].temp_attentions[1].transformer_blocks[0].attn2
up_blocks[1].temp_attentions[0].transformer_blocks[0].attn2
up_blocks[1].temp_attentions[1].transformer_blocks[0].attn2
up_blocks[1].temp_attentions[2].transformer_blocks[0].attn2
up_blocks[2].temp_attentions[0].transformer_blocks[0].attn2
up_blocks[2].temp_attentions[1].transformer_blocks[0].attn2
up_blocks[2].temp_attentions[2].transformer_blocks[0].attn2
up_blocks[3].temp_attentions[0].transformer_blocks[0].attn2
up_blocks[3].temp_attentions[1].transformer_blocks[0].attn2
up_blocks[3].temp_attentions[2].transformer_blocks[0].attn2
mid_block.temp_attentions[0].transformer_blocks[0].attn2

All of the temporal modules above return None, while only the .attentions layers and the transformer_in layer in the second set below have cross-attention maps.

- Second set

transformer_in.transformer_blocks[0].attn2
torch.Size([64, 64, 24, 24])
down_blocks[0].attentions[0].transformer_blocks[0].attn2
torch.Size([120, 64, 64, 77])
down_blocks[0].attentions[1].transformer_blocks[0].attn2
torch.Size([120, 64, 64, 77])
down_blocks[1].attentions[0].transformer_blocks[0].attn2
torch.Size([240, 32, 32, 77])
down_blocks[1].attentions[1].transformer_blocks[0].attn2
torch.Size([240, 32, 32, 77])
down_blocks[2].attentions[0].transformer_blocks[0].attn2
torch.Size([480, 16, 16, 77])
down_blocks[2].attentions[1].transformer_blocks[0].attn2
torch.Size([480, 16, 16, 77])
up_blocks[1].attentions[0].transformer_blocks[0].attn2
torch.Size([480, 16, 16, 77])
up_blocks[1].attentions[1].transformer_blocks[0].attn2
torch.Size([480, 16, 16, 77])
up_blocks[1].attentions[2].transformer_blocks[0].attn2
torch.Size([480, 16, 16, 77])
up_blocks[2].attentions[0].transformer_blocks[0].attn2
torch.Size([240, 32, 32, 77])
up_blocks[2].attentions[1].transformer_blocks[0].attn2
torch.Size([240, 32, 32, 77])
up_blocks[2].attentions[2].transformer_blocks[0].attn2
torch.Size([240, 32, 32, 77])
up_blocks[3].attentions[0].transformer_blocks[0].attn2
torch.Size([120, 64, 64, 77])
up_blocks[3].attentions[1].transformer_blocks[0].attn2
torch.Size([120, 64, 64, 77])
up_blocks[3].attentions[2].transformer_blocks[0].attn2
torch.Size([120, 64, 64, 77])
mid_block.attentions[0].transformer_blocks[0].attn2
torch.Size([480, 8, 8, 77])
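To narrow down why the first set comes back empty, I also tried checking whether the temporal attn2 modules receive any text embeddings at all; if encoder_hidden_states never reaches them, they would effectively run as self-attention over frames. This is just my own debugging sketch (it assumes PyTorch >= 2.0 for with_kwargs=True), not code from the repo:

def make_probe(name):
    def probe(module, args, kwargs):
        # encoder_hidden_states may arrive positionally or as a keyword
        enc = kwargs.get("encoder_hidden_states")
        if enc is None and len(args) > 1:
            enc = args[1]
        print(name, "encoder_hidden_states:",
              None if enc is None else tuple(enc.shape))
    return probe

for name, module in self.unet.named_modules():
    if type(module).__name__ == "Attention" and "attn2" in name \
            and "temp_attentions" in name:
        module.register_forward_pre_hook(make_probe(name), with_kwargs=True)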

  2. If one assumes that the second set contains the spatial attention maps, this does not align with the modules listed in the supplemental document (page 1, screenshot included). In particular, transformer_in.transformer_blocks[0].attn2 has shape (64, 64, 24, 24), suggesting it is temporal (not spatial, as stated in the supplemental) with 24 frames, while mid_block.attentions[0].transformer_blocks[0].attn2 has shape (480, 8, 8, 77), suggesting it is a spatial attention map (not temporal) with 77 text tokens.
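For what it is worth, the shape arithmetic below is what makes me read the second set as spatial. It assumes a batch of 1, 24 frames, and attention_head_dim 64 (so 5/10/20 heads at the 320/640/1280-channel levels), which is my guess for this model rather than something I verified in the repo:

# If a spatial cross-attention map is (batch * frames * heads, height, width, text_tokens):
frames, tokens = 24, 77
for heads, res in [(5, 64), (10, 32), (20, 16), (20, 8)]:
    print((1 * frames * heads, res, res, tokens))
# -> (120, 64, 64, 77), (240, 32, 32, 77), (480, 16, 16, 77), (480, 8, 8, 77),
# which matches the second set above, whereas the transformer_in map ends in
# (24, 24), i.e. frames attending to frames rather than to the 77 text tokens.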

Your kind clarification would be very helpful. Thanks!

@danielajisafe (Author)

Extract from page 1 of the supplemental document. Hope to hear from you, and thanks!

[Screenshot: excerpt from page 1 of the supplemental document]
