IPAdapter not compatible to torch.compile #7985

Closed
rootonchair opened this issue May 19, 2024 · 3 comments
Labels
bug Something isn't working

@rootonchair
Contributor

rootonchair commented May 19, 2024

Describe the bug

Using IPAdapter with torch.compile raises the following error:

torch._dynamo.exc.Unsupported: UNPACK_SEQUENCE NNModuleVariable()

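The error is caused by the following line, where Dynamo fails to unpack a module as a sequence (UNPACK_SEQUENCE on an NNModuleVariable) in the loop over self.layers: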

https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/embeddings.py#L879

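It can be worked around by collecting each layer's sub-modules into a plain Python list and unpacking that list instead, with a patch like: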

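    # Each entry of self.layers is itself a module container holding
    # (ln0, ln1, attn, ff); avoid unpacking that module directly.
    for layer in self.layers:
        residual = latents

        # Copy the sub-modules into a plain Python list, which Dynamo can unpack
        comps = []
        for module in layer:
            comps.append(module)
        ln0, ln1, attn, ff = comps

        encoder_hidden_states = ln0(x)
        latents = ln1(latents)
        encoder_hidden_states = torch.cat([encoder_hidden_states, latents], dim=-2)
        latents = attn(latents, encoder_hidden_states) + residual
        latents = ff(latents) + latents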
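To compare the two loop styles without downloading any weights, here is a minimal, self-contained sketch. ToyResampler, its sizes, and the use of nn.MultiheadAttention plus a small nn.Sequential feed-forward are hypothetical stand-ins, not the diffusers implementation; the unpacking loop is assumed to mirror the linked line, as the UNPACK_SEQUENCE NNModuleVariable() message suggests. The sketch only checks that the patched loop computes the same result as the unpacking loop in eager mode.

```python
import torch
from torch import nn


class ToyResampler(nn.Module):
    """Hypothetical stand-in: each entry of self.layers is an nn.ModuleList
    holding (ln0, ln1, attn, ff), the structure the patch iterates over."""

    def __init__(self, dim: int = 16, num_latents: int = 4, depth: int = 2):
        super().__init__()
        self.latents = nn.Parameter(torch.randn(1, num_latents, dim))
        self.layers = nn.ModuleList(
            [
                nn.ModuleList(
                    [
                        nn.LayerNorm(dim),
                        nn.LayerNorm(dim),
                        nn.MultiheadAttention(dim, num_heads=2, batch_first=True),
                        nn.Sequential(nn.LayerNorm(dim), nn.Linear(dim, dim)),
                    ]
                )
                for _ in range(depth)
            ]
        )

    @staticmethod
    def _block(ln0, ln1, attn, ff, x, latents):
        residual = latents
        encoder_hidden_states = ln0(x)
        latents = ln1(latents)
        encoder_hidden_states = torch.cat([encoder_hidden_states, latents], dim=-2)
        # nn.MultiheadAttention returns (output, weights); keep only the output
        latents = attn(latents, encoder_hidden_states, encoder_hidden_states)[0] + residual
        return ff(latents) + latents

    def forward_unpacking(self, x):
        # Unpacks each inner nn.ModuleList directly in the loop target: the
        # pattern behind the reported Dynamo error (PyTorch 2.1.0 in this issue).
        latents = self.latents.repeat(x.size(0), 1, 1)
        for ln0, ln1, attn, ff in self.layers:
            latents = self._block(ln0, ln1, attn, ff, x, latents)
        return latents

    def forward_patched(self, x):
        # Patched pattern from the issue: copy the sub-modules into a plain
        # Python list first, then unpack the list.
        latents = self.latents.repeat(x.size(0), 1, 1)
        for layer in self.layers:
            comps = []
            for module in layer:
                comps.append(module)
            ln0, ln1, attn, ff = comps
            latents = self._block(ln0, ln1, attn, ff, x, latents)
        return latents


torch.manual_seed(0)
model = ToyResampler().eval()
x = torch.randn(3, 5, 16)  # (batch, tokens, dim)
with torch.no_grad():
    out_unpacking = model.forward_unpacking(x)
    out_patched = model.forward_patched(x)
print(out_patched.shape, torch.allclose(out_unpacking, out_patched))
```

Both methods call the same modules in the same order; only the way the modules are pulled out of the container differs, which is why the patch does not change the model's output.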

Reproduction

from diffusers import AutoPipelineForText2Image
from diffusers.utils import load_image
import torch

pipeline = AutoPipelineForText2Image.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16).to("cuda")
pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name=["ip-adapter-plus_sdxl_vit-h.safetensors"])
pipeline.set_ip_adapter_scale(0.6)
pipeline.unet.to(memory_format=torch.channels_last)
pipeline.unet = torch.compile(pipeline.unet, mode="max-autotune", fullgraph=True)
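Note that torch.compile is lazy: the last line above only wraps the UNet. Dynamo traces the module, and therefore raises the Unsupported error, on the first forward call, i.e. once the pipeline is actually run with an IP-Adapter image. The toy sketch below shows this timing with a custom backend; counting_backend and TinyBlock are hypothetical names, and the backend simply records each captured FX graph and runs it unchanged. The toy loop uses the patched iteration pattern from the issue, so it is expected to trace as a single graph under fullgraph=True.

```python
import torch
from torch import nn

captured_graphs = []  # one entry per FX graph Dynamo passes to the backend


def counting_backend(gm: torch.fx.GraphModule, example_inputs):
    # A custom torch.compile backend receives the captured FX graph and example
    # inputs and must return a callable; here we record the graph and run it as-is.
    captured_graphs.append(gm)
    return gm.forward


class TinyBlock(nn.Module):
    def __init__(self, dim: int = 8):
        super().__init__()
        self.layers = nn.ModuleList([nn.ModuleList([nn.Linear(dim, dim), nn.ReLU()])])

    def forward(self, x):
        for layer in self.layers:
            # Same workaround as the patch: unpack a plain list, not the module
            comps = []
            for module in layer:
                comps.append(module)
            linear, act = comps
            x = act(linear(x))
        return x


compiled = torch.compile(TinyBlock(), backend=counting_backend, fullgraph=True)
graphs_before_first_call = len(captured_graphs)  # wrapping alone traces nothing
out = compiled(torch.randn(2, 8))
graphs_after_first_call = len(captured_graphs)  # tracing happened on this call
print(graphs_before_first_call, graphs_after_first_call)
```

For the reproduction above, this means the error only shows up once the pipeline is called with an ip_adapter_image, not at the torch.compile line itself.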

Logs

No response

System Info

  • diffusers version: 0.28.0.dev0
  • Platform: Linux-4.18.0-408.el8.x86_64-x86_64-with-glibc2.17
  • Python version: 3.8.13
  • PyTorch version (GPU?): 2.1.0+cu121 (True)
  • Huggingface_hub version: 0.21.4
  • Transformers version: 4.40.0.dev0
  • Accelerate version: 0.28.0
  • xFormers version: 0.0.22.post7
  • Using GPU in script?: Yes
  • Using distributed or parallel set-up in script?: No

Who can help?

@sayakpaul @yiyixuxu

@rootonchair rootonchair added the bug Something isn't working label May 19, 2024
@sayakpaul
Member

Thanks for sending over the patch, too. Would you like to take a stab at sending over a PR?

@rootonchair
Contributor Author

Sure, PR on the way

@sayakpaul
Member

Closing with #7994
