DALLE2-pytorch: cannot generate a normal image with the pretrained models

This is my code for generating an image, but the generated image is random noise.

Prior model: https://huggingface.co/laion/DALLE2-PyTorch/blob/main/prior/best.pth
Decoder model: https://huggingface.co/laion/DALLE2-PyTorch/blob/main/decoder/1.5B/latest.pth

import torch
from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork, OpenAIClipAdapter
from dalle2_pytorch import Unet, Decoder, DALLE2

prior_network = DiffusionPriorNetwork(
    dim=768,
    depth=12,
    dim_head=64,
    heads=12,
    normformer=True,
    attn_dropout=5e-2,
    ff_dropout=5e-2,
    num_time_embeds=1,
    num_image_embeds=1,
    num_text_embeds=1,
    num_timesteps=1000,
    ff_mult=4,
    final_proj=True,
    rotary_emb=True
)

diffusion_prior = DiffusionPrior(
    net=prior_network,
    clip=OpenAIClipAdapter("ViT-L/14"),
    image_embed_dim=768,
    timesteps=1000,
    # sample_timesteps = 64,
    cond_drop_prob=0.1,
    loss_type="l2",
    condition_on_text_encodings=True,
)

# strict=False silently skips any mismatched keys in the checkpoint — see the comments below
diffusion_prior.load_state_dict(torch.load("prior.pth", map_location=torch.device('cpu')), strict=False)

unet = Unet(
    dim=320,
    cond_dim=512,
    image_embed_dim=768,
    text_embed_dim=768,
    cond_on_text_encodings=True,
    channels=3,
    dim_mults=[1, 2, 3, 4],
    num_resnet_blocks=4,
    attn_heads=8,
    attn_dim_head=64,
    sparse_attn=True,
    memory_efficient=True,
    self_attn=[False, True, True, True],
)

decoder = Decoder(
    unet=unet,
    clip=OpenAIClipAdapter("ViT-L/14"),
    timesteps=1000,
    image_sizes=[64],
    image_cond_drop_prob=0.1,
    text_cond_drop_prob=0.5,
    learned_variance=True,
)
# strict=False likewise hides the missing "clip.*" keys discussed in the comments below
decoder.load_state_dict(torch.load("decoder.pth", map_location=torch.device('cpu')), strict=False)

dalle2 = DALLE2(
    prior=diffusion_prior,
    decoder=decoder,
)

images = dalle2(
    ['cute puppy chasing after a squirrel'],
    cond_scale=2.,  # classifier-free guidance strength (> 1 strengthens the conditioning)
    return_pil_images=True,
)
for i, img in enumerate(images):
    img.save(f"out_{i}.jpg")  # enumerate so multiple images are not overwritten
[Screenshot 2023-03-08 18:53:58: the generated output is random noise]

Most upvoted comments

@ZhangxinruBIT

You didn’t mention remapping the keys for the decoder. That was something I mentioned and included in the code:

# Copy the CLIP weights from the model into the loaded checkpoint under the
# "clip." prefix, so the checkpoint's keys match the model defined in the code.
for k in decoder.clip.state_dict().keys():
    decoder_model_state["clip." + k] = decoder.clip.state_dict()[k]

I only discovered this because I had strict=True in load_state_dict. Without the modification above, the script errors out because the keys in the .pth do not match the model defined in the code (the checkpoint was apparently saved without the frozen CLIP weights, while the Decoder here wraps CLIP as a submodule). With strict=False, as in your script, no error is raised, but the mismatched weights are silently skipped, so noisy results appear. Note that strict defaults to True in PyTorch.

First set strict=True for both the prior and the decoder. If a key mismatch occurs for the decoder and the error mentions missing keys prefixed with clip., paste the above code between the torch.load call and load_state_dict; the full sequence is sketched below.
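Putting it together, a minimal sketch of the fixed decoder loading sequence, assuming the .pth is a raw state dict as in the original script (file names are placeholders):

import torch

# Load the raw checkpoint without applying it to the model yet.
decoder_model_state = torch.load("decoder.pth", map_location=torch.device('cpu'))

# Patch the checkpoint: copy the model's own CLIP weights in under the "clip." prefix.
for k in decoder.clip.state_dict().keys():
    decoder_model_state["clip." + k] = decoder.clip.state_dict()[k]

# A strict load now succeeds, and would surface any remaining mismatch instead of hiding it.
decoder.load_state_dict(decoder_model_state, strict=True)

# Load the prior the same way; apply the same patch if it also reports missing clip.* keys.
diffusion_prior.load_state_dict(torch.load("prior.pth", map_location=torch.device('cpu')), strict=True)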

This worked for me, thank you!

@cest-andre thanks again for your time!

@tikitong I think you’ve got it working. The fact that you get a clean image that is at least car-related suggests it is working properly. The imperfect results are more a function of the model’s limitations than of any coding mistake. I’ve also received some “off” results.

The model is non-deterministic, so you can run it multiple times and see whether you get better images (see the sketch below). But I think you’re good to go.
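For instance, a minimal sketch of drawing several samples for one prompt (the prompt and file names are just illustrative):

for i in range(4):
    images = dalle2(
        ['a photo of a car'],
        cond_scale=2.,
        return_pil_images=True,
    )
    images[0].save(f"car_sample_{i}.jpg")  # each run draws fresh noise, so the outputs differ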

@tikitong Sure thing. A reminder that this fix requires downgrading dalle2-pytorch to version 1.1.0 (pip install dalle2-pytorch==1.1.0). Could you be more specific about what doesn’t work on your end? Do you get an error, or bad image results?

Here are the results from the prompt ‘a field of flowers’:

[Image: field_of_flowers — generated result for the prompt ‘a field of flowers’]