DALLE2-pytorch: Error related to text mask
  File "/fsx/dalle2/.dalle_env_38/lib64/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
    losses = self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start, learned_variance = learned_variance, is_latent_diffusion = is_latent_diffusion, noise_scheduler = noise_scheduler)
  File "/fsx/dalle2/DALLE2-pytorch/dalle2_pytorch/dalle2_pytorch.py", line 2340, in p_losses
    model_output = unet(
  File "/fsx/dalle2/.dalle_env_38/lib64/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/fsx/dalle2/DALLE2-pytorch/dalle2_pytorch/dalle2_pytorch.py", line 1817, in forward
    text_keep_mask = text_mask & text_keep_mask
RuntimeError: The size of tensor a (0) must match the size of tensor b (26) at non-singleton dimension 0
This seems to have been introduced recently, but I’m not sure by which commit.
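The RuntimeError above is a broadcasting failure: when two tensors are combined elementwise, their shapes are aligned from the trailing dimension, and each pair of sizes must be equal or contain a 1. A size of 0 against a size of 26 satisfies neither. The following is a minimal, framework-free sketch of that rule. The helper name `broadcast_shapes` and the concrete shapes `(0, 256)` and `(26, 1)` are assumptions chosen only to reproduce the sizes in the message; the real shapes inside the library may differ.

```python
def broadcast_shapes(a, b):
    """Return the broadcast shape of two shapes, or raise ValueError.

    Shapes are aligned from the trailing dimension; two sizes are
    compatible if they are equal or if one of them is 1.
    """
    ndim = max(len(a), len(b))
    result = []
    # Walk the dimensions from last to first, padding the shorter shape with 1s.
    for i in range(1, ndim + 1):
        size_a = a[-i] if i <= len(a) else 1
        size_b = b[-i] if i <= len(b) else 1
        if size_a != size_b and size_a != 1 and size_b != 1:
            raise ValueError(
                f"size {size_a} must match size {size_b} "
                f"at non-singleton dimension {ndim - i}"
            )
        result.append(size_b if size_a == 1 else size_a)
    return tuple(reversed(result))


# Hypothetical shapes: an empty text mask against a per-sample keep mask.
empty_text_mask_shape = (0, 256)
text_keep_mask_shape = (26, 1)

try:
    broadcast_shapes(empty_text_mask_shape, text_keep_mask_shape)
except ValueError as err:
    print(err)
```

Under these assumptions, the failure suggests that one of the two operands arrived with an empty leading dimension, so checking the shape of `text_mask` right before the `&` is a reasonable first debugging step.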
About this issue
- State: closed
- Created 2 years ago
- Comments: 49 (22 by maintainers)
awesome! i’ll give it a shot tomorrow, i’m gonna sign off for the afternoon.
@nousr threw in yet another stability measure https://github.com/lucidrains/DALLE2-pytorch/commit/349aaca56fe66a6f0fc6720c91ee5d7fa1e36f93
worst comes to worst, we can always rollback
thanks for letting me know about the dreaded NaNs!
maybe, it’s more like all i get is NaN loss though, it’s not just a few
sure, I don’t really see any problem with that. I’m more just trying to get to the bottom of what’s causing the NaN values…
I’m a few minutes away from just running this on my local machine and stepping through with a debugger to see where the tensor is becoming NaN (subsequently heating the apartment)
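The manual approach described here, stepping through until a value first becomes NaN, can also be automated by checking the output of each stage in order and stopping at the first non-finite one. The sketch below shows the idea without any framework. The function `first_non_finite_stage` and the three toy stages are hypothetical and do not come from the repository. In PyTorch the same check could plausibly be attached to each submodule with `register_forward_hook` and `torch.isfinite`, and `torch.autograd.set_detect_anomaly(True)` can flag NaNs produced in the backward pass.

```python
import math


def first_non_finite_stage(stages, values):
    """Run values through named stages in order.

    Returns the name of the first stage whose output contains NaN or
    infinity, or None if every stage produces finite values.
    """
    for name, fn in stages:
        values = [fn(v) for v in values]
        if not all(math.isfinite(v) for v in values):
            return name
    return None


# Hypothetical pipeline: "scale" overflows to infinity, and "center"
# would then turn that infinity into NaN (inf - inf).
stages = [
    ("embed", lambda v: v * 2.0),
    ("scale", lambda v: v * 1e308),
    ("center", lambda v: v - v),
]

culprit = first_non_finite_stage(stages, [1.0, 200.0])
print(culprit)
```

Checking for non-finite values rather than only NaN matters here, since an overflow to infinity in one stage is often what produces the NaN in the next.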
@lucidrains I was getting NaNs on the prior’s .forward when I rebased to main. Reverting to the following commit’s lines fixed the issue for me, just thought i’d throw it here
https://github.com/lucidrains/DALLE2-pytorch/blob/3dae43fa0ef2c5a29df9c275b1ebeb1a65c3bac5/dalle2_pytorch/dalle2_pytorch.py#L310-L318
https://github.com/lucidrains/DALLE2-pytorch/blob/3dae43fa0ef2c5a29df9c275b1ebeb1a65c3bac5/dalle2_pytorch/dalle2_pytorch.py#L866-L867
https://github.com/lucidrains/DALLE2-pytorch/blob/3dae43fa0ef2c5a29df9c275b1ebeb1a65c3bac5/dalle2_pytorch/dalle2_pytorch.py#L1177-L1178
ok done, no masks no worries 😄