Cannot generate normal image with pretrained model #282
I initially also had trouble getting pretrained weights working properly with this repository, but I resolved the problem. Not sure if you have the same problem, but I'll relay it in case it helps.

First, it was helpful to set strict=True so that I could see the discrepancies between the weights and the model. The key to fixing my problem was noticing that there were parameters defined in the model which did not exist in the pth file. Seeing that the latest pth file was ~8 months old, I downgraded dalle2-pytorch to version 1.1.0. The prior weights now load, but the decoder pth was still missing parameters related to CLIP. I then took the state_dict from the pretrained CLIP and merged it into the decoder's pth-derived state_dict.

Finally, I had to modify the dalle2_pytorch.py file, as there seemed to be a bug in the DALLE2 class' forward method: on line 2940, image_embed is not passed to decoder.sample's image_embed argument, which produces an error when called. This bug has been fixed in more recent repo versions, but for 1.1.0 you need to modify the code:
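Since a state_dict is just a mapping from parameter names to tensors, the discrepancies that strict=True reports can be reproduced by comparing key sets by hand. A minimal sketch with made-up key names (not the real dalle2-pytorch parameter names):

```python
# Plain sets standing in for state_dict keys; the names below are
# hypothetical, chosen only to illustrate the kind of mismatch described.
ckpt_keys = {"unet.weight", "unet.bias"}                       # keys found in the .pth file
model_keys = {"unet.weight", "unet.bias", "clip.visual.proj"}  # keys the model defines

missing = model_keys - ckpt_keys     # defined in the model, absent from the checkpoint
unexpected = ckpt_keys - model_keys  # present in the checkpoint, unknown to the model

print(sorted(missing))     # ['clip.visual.proj']
print(sorted(unexpected))  # []
```

load_state_dict with strict=True raises on exactly these two sets, so inspecting them first tells you whether the checkpoint or the model definition is the thing that needs patching.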
Here is my script to get things running:
Let me know if you have any questions. Hope this helps.
@cest-andre Hi! I am seeing the same error as LIUHAO121 even with the above, and I'm wondering if I might be using an incorrect set of files (prior_config.json, decoder_config.json, decoder_latest.pth, and prior_latest.pth). I'm currently downloading these files from the huggingface repository -- is this correct? Also, would the json files need to be modified after downloading (e.g. to fix paths)?
@klei22 For the prior, I am using latest.pth and prior_config.json in huggingface's prior folder. Note that in the code I posted, I modify the keys in the decoder's state dictionary so that it recognizes the CLIP weights.
@cest-andre thank you very much for sharing. It doesn't seem to work on my end either. Can you share an example of the input text and the image you get as output?
@tikitong Here are the results from the prompt 'a field of flowers':
@cest-andre @tikitong I ran into the same problem. I first pinned dalle2-pytorch to version 1.1.0 (and double-checked this). I then modified the code on line 2940 in the dalle2_pytorch.py file. Finally, I downloaded the pretrained models and their config JSONs, the same as you. With the input text, no error is reported, and I get the 64×64 2D image shown below. I have no idea what the issue is now.
You didn't mention changing the keys in the decoder's state dictionary. This was something I mentioned and included in the code.
I only discovered this because I had strict=True in load_state_dict. Without that modification, the script errors out because the keys in the pth do not match the model defined in the code. If strict is False (the default), no error occurs, but the weights are not properly loaded, so you get noisy results. First set strict=True for both the prior and the decoder. If a key mismatch occurs for the decoder and it mentions missing keys referring to CLIP, paste the code above between where you first load the pth (torch.load) and the call to load_state_dict.
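As a toy illustration of that fix (plain dicts standing in for state_dicts, with hypothetical key names and values), copying the already-loaded CLIP weights into the checkpoint dict under a "clip." prefix is all the patch does:

```python
# Hypothetical decoder checkpoint that lacks the CLIP entries the model expects
decoder_model_state = {"unet.weight": 0.1}

# Stand-in for decoder.clip.state_dict(), i.e. weights from the pretrained CLIP
clip_state = {"visual.proj": 0.2, "logit_scale": 0.3}

# Re-key the CLIP weights so they match the decoder's "clip."-prefixed naming;
# the merged dict can then be passed to load_state_dict(..., strict=True)
for k, v in clip_state.items():
    decoder_model_state["clip." + k] = v

print(sorted(decoder_model_state))
# ['clip.logit_scale', 'clip.visual.proj', 'unet.weight']
```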
@cest-andre thanks too for your reply! I have no error, but the image does not match the input text. Do you have any idea where the problem might come from? Here is the code I used; the weight files are the ones you indicated before (for the prior, latest.pth and prior_config.json in huggingface's prior folder; for the decoder, latest.pth and decoder_config.json in the decoder/v1.0.2 folder).

```python
import torch
from torchvision.transforms import ToPILImage
from dalle2_pytorch import DALLE2
from dalle2_pytorch.train_configs import TrainDiffusionPriorConfig, TrainDecoderConfig

# Load the diffusion prior from its config and checkpoint
prior_config = TrainDiffusionPriorConfig.from_json_path("weights/prior_config.json").prior
prior = prior_config.create()
prior_model_state = torch.load("weights/prior_latest.pth", map_location=torch.device('cpu'))
prior.load_state_dict(prior_model_state, strict=True)

# Load the decoder, merging the pretrained CLIP weights into its state dict
decoder_config = TrainDecoderConfig.from_json_path("weights/decoder_config.json").decoder
decoder = decoder_config.create()
decoder_model_state = torch.load("weights/decoder_latest.pth", map_location=torch.device('cpu'))["model"]

for k in decoder.clip.state_dict().keys():
    decoder_model_state["clip." + k] = decoder.clip.state_dict()[k]

decoder.load_state_dict(decoder_model_state, strict=True)

dalle2 = DALLE2(prior=prior, decoder=decoder)

images = dalle2(
    ['a red car'],
    cond_scale = 2.
).cpu()

for img in images:
    img = ToPILImage()(img)
    img.show()
```

Here is the configuration of my conda environment:
@tikitong The model is non-deterministic, so you can run it multiple times and see if you get better images. But I think you're good to go.
@cest-andre thanks again for your time!
@cest-andre Can you help me figure it out?
You're not using exactly the same code: you're moving your prior and decoder to the CPU. I've only run everything on GPU, so I'm not sure what the exact problem is. Maybe you need to move the DALLE2 object to the CPU as well.
@cest-andre
@kdavidlp123 |
Hi, @tikitong, sorry to bother you, where did you import
If I set strict=False, there is no problem, but it only generates almost random images; with strict=True, for example, the error shows:
This worked for me, thank you!
Hi, I solved this problem by deleting the folder named |
I got it running by referring to your answer. Many thanks!
@cest-andre can you please attach the updated code here for others to use as an initial reference? That would be a great help.
The bottom of my first comment above has the code. |
@cest-andre I'm getting the following pydantic error while running the mentioned code (the traceback points to https://errors.pydantic.dev/2.0.3/u/root-validator-pre-skip). Solution:
Hi, I got an error from the decoder. I suppose it's the dalle2-pytorch version; does it actually work with the current files in the repository?
This is my code for generating images, but the generated images are very blurry.

```python
prior_config = TrainDiffusionPriorConfig.from_json_path("weights/prior_config.json").prior
prior = prior_config.create()
prior_model_state = torch.load("weights/prior_latest.pth")
prior.load_state_dict(prior_model_state, strict=True)

decoder_config = TrainDecoderConfig.from_json_path("weights/decoder_config.json").decoder
decoder = decoder_config.create()
decoder_model_state = torch.load("weights/decoder_latest.pth")["model"]

for k in decoder.clip.state_dict().keys():
    decoder_model_state["clip." + k] = decoder.clip.state_dict()[k]

decoder.load_state_dict(decoder_model_state, strict=True)

dalle2 = DALLE2(prior=prior, decoder=decoder).cuda()
images = dalle2(
```
Where is the cuz folder, bro?
I have the same problem, and even after I changed the version to 1.1.0, there are still errors.
This is my code for generating an image, but the generated image is random.
prior model: https://huggingface.co/laion/DALLE2-PyTorch/blob/main/prior/best.pth
decoder model: https://huggingface.co/laion/DALLE2-PyTorch/blob/main/decoder/1.5B/latest.pth