DALLE2-pytorch: Error related to text mask

File "/fsx/dalle2/.dalle_env_38/lib64/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
losses = self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start, learned_variance = learned_variance, is_latent_diffusion = is_latent_diffusion, noise_scheduler = noise_scheduler)
File "/fsx/dalle2/DALLE2-pytorch/dalle2_pytorch/dalle2_pytorch.py", line 2340, in p_losses
return forward_call(*input, **kwargs)
File "/fsx/dalle2/DALLE2-pytorch/dalle2_pytorch/dalle2_pytorch.py", line 1817, in forward
model_output = unet(
File "/fsx/dalle2/.dalle_env_38/lib64/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
text_keep_mask = text_mask & text_keep_mask
RuntimeError: The size of tensor a (0) must match the size of tensor b (26) at non-singleton dimension 0
return forward_call(*input, **kwargs)
File "/fsx/dalle2/DALLE2-pytorch/dalle2_pytorch/dalle2_pytorch.py", line 1817, in forward
text_keep_mask = text_mask & text_keep_mask
RuntimeError: The size of tensor a (0) must match the size of tensor b (26) at non-singleton dimension 0

This was introduced recently, but I'm not sure by which commit.
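
For reference, the broadcast failure can be reproduced in isolation: the unet ANDs the text_mask coming from the dataloader with a randomly drawn per-sample keep mask for classifier-free guidance, so if the text mask arrives with an empty leading dimension while the batch holds 26 samples, PyTorch raises exactly this error. A minimal standalone sketch (the shapes and the cond_drop_prob value are assumptions for illustration, not the library's code):

    import torch

    batch = 26
    cond_drop_prob = 0.1  # assumed value, for illustration only

    # per-sample keep mask drawn for classifier-free guidance, shape (26,)
    text_keep_mask = torch.zeros(batch).uniform_(0, 1) > cond_drop_prob
    # text mask that arrives empty from the trainer, shape (0,)
    text_mask = torch.ones(0, dtype = torch.bool)

    text_keep_mask = text_mask & text_keep_mask
    # RuntimeError: The size of tensor a (0) must match the size of tensor b (26)
    # at non-singleton dimension 0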

About this issue

  • State: closed
  • Created 2 years ago
  • Comments: 49 (22 by maintainers)

Most upvoted comments

Awesome! I'll give it a shot tomorrow; I'm gonna sign off for the afternoon.

@nousr threw in yet another stability measure https://github.com/lucidrains/DALLE2-pytorch/commit/349aaca56fe66a6f0fc6720c91ee5d7fa1e36f93

worst comes to worst, we can always rollback

thanks for letting me know about the dreaded NaNs!

@nousr yeah, it could just be the usual transformer instability

Maybe, but it's more like all I get is NaN loss; it's not just a few.

@nousr I think we should just follow the new masking rule, which is that any text encoding must be padded with zeros during the dataloader collation
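
A hypothetical sketch of what that collation rule could look like, assuming each sample carries its text encoding as a (seq_len, dim) tensor (the dict keys and helper name are made up for illustration):

    import torch
    from torch.nn.utils.rnn import pad_sequence

    def collate_with_zero_padded_text(batch):
        # stack fixed-size image tensors as usual
        images = torch.stack([item['image'] for item in batch])
        # pad variable-length text encodings with zeros along the sequence dimension,
        # so padded positions are all-zero rows the decoder can treat as masked
        text_encodings = pad_sequence(
            [item['text_encoding'] for item in batch],  # list of (seq_len_i, dim) tensors
            batch_first = True,
            padding_value = 0.
        )  # -> (batch, max_seq_len, dim)
        return images, text_encodings

With that convention, a mask (if one is still needed) can be recovered as the non-zero positions, e.g. (text_encodings != 0).any(dim = -1), rather than being passed around separately.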

sure, I don’t really see any problem with that. I’m more just trying to get to the bottom of what’s causing the NaN values…

I’m a few minutes away from just running this on my local machine and stepping through with a debugger to see where the tensor is becoming NaN (subsequently heating the apartment)
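
One way to localize that without single-stepping, sketched here with an assumed model variable: register forward hooks that raise on the first module whose output contains a NaN (torch.autograd.set_detect_anomaly(True) does the analogous check on the backward pass).

    import torch

    def nan_hook(module, inputs, output):
        # flag the first module whose output contains a NaN
        outs = output if isinstance(output, tuple) else (output,)
        for out in outs:
            if torch.is_tensor(out) and out.is_floating_point() and torch.isnan(out).any():
                raise RuntimeError(f'NaN in output of {module.__class__.__name__}')

    def register_nan_hooks(model):
        # `model` would be whatever decoder / unet is being trained
        return [m.register_forward_hook(nan_hook) for m in model.modules()]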

ok done, no masks no worries 😄