imagen-pytorch: fp16 causes loss to be nan?

Hi! When using the fp16 option, my loss becomes nan. I’m using a V100. Is there any other option I need to configure besides fp16=True in the trainer?
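
For reference, here is roughly what I’m running (a minimal sketch; the hyperparameters and tensors below are placeholders rather than my exact config or data pipeline):

```python
import torch
from imagen_pytorch import Unet, Imagen, ImagenTrainer

# placeholder hyperparameters, not my exact config
unet = Unet(
    dim = 32,
    dim_mults = (1, 2, 4, 8),
    num_resnet_blocks = 3,
    layer_attns = (False, True, True, True)
)

imagen = Imagen(
    unets = (unet,),
    image_sizes = (64,),
    timesteps = 1000,
    cond_drop_prob = 0.1
)

# fp16 = True is the only mixed-precision flag I'm setting
trainer = ImagenTrainer(imagen, fp16 = True)

images = torch.randn(8, 3, 64, 64)        # stand-in for a real image batch
text_embeds = torch.randn(8, 256, 768)    # stand-in for precomputed t5 embeddings

loss = trainer(images, text_embeds = text_embeds, unet_number = 1)
trainer.update(unet_number = 1)
```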

About this issue

  • State: closed
  • Created 2 years ago
  • Reactions: 1
  • Comments: 37 (26 by maintainers)

Most upvoted comments

@vedantroy you should let gwern know that you are training on danbooru 😆

@vedantroy yea, pytorch autocast will take care of converting the types at the boundaries of the autocast region
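
roughly, it behaves like this (plain pytorch sketch, nothing imagen-specific):

```python
import torch

linear = torch.nn.Linear(16, 16).cuda()   # parameters stay in fp32
x = torch.randn(8, 16, device = 'cuda')

with torch.cuda.amp.autocast():
    y = linear(x)       # inside the region, the matmul runs in fp16
    print(y.dtype)      # torch.float16

z = linear(x)           # outside the region, everything is back to fp32
print(z.dtype)          # torch.float32
```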

@vedantroy thanks for pressing on this! i believe i have found the problem and the fix is in 1.1.6

@vedantroy one last thing to try would be to set a warmup for the learning rate https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/trainer.py#L201

you can try 10000 steps for starters
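
something like this, assuming the warmup is exposed as a `warmup_steps` kwarg on the trainer (per the linked line):

```python
# assumes the warmup is exposed as a `warmup_steps` kwarg (see the linked trainer.py line)
trainer = ImagenTrainer(
    imagen,
    fp16 = True,
    lr = 1e-4,
    warmup_steps = 10000   # ramp the learning rate up over the first 10k steps
)
```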

@vedantroy doh, yea, i think your code looks ok

do you want to try v1.1.5? 5fca687

Hm, that didn’t fix the issue. 😕 Would love to assist with debugging but not sure how. If you want, I can DM you SSH access to my machine, or are there any diagnostics you’d like me to run on it?

@lucidrains

@vedantroy ok! if you run into NaNs later on, try turning that on

if it works even with it off, then it is fine, you can just keep it that way

I’m happy to turn cosine_sim_attn=True if there’s some experiment you want to run. But I’m guessing the experiment requires me hitting NaNs first?

@vedantroy at minimum it should be 16 or 32

Got it, so what is going on here is that the gradients are accumulated over 32 samples per update, but the data is processed in batches of 8.

And no, I’m not using cosine_sim_attn. I’ll restart the run now with that enabled.
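
For anyone following along, this is the flag I’m flipping (sketch only; assuming it’s exposed as a Unet kwarg as in the linked commit, with placeholder hyperparameters):

```python
# assuming the flag is exposed as a Unet kwarg, as in the linked commit
unet = Unet(
    dim = 32,
    dim_mults = (1, 2, 4, 8),
    layer_attns = (False, True, True, True),
    cosine_sim_attn = True   # l2-normalized q/k attention, intended to be more fp16-friendly
)
```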

try increasing the batch size to 32, then use max_batch_size of 8 or 4 to keep it from OOM

That worked! I’d love an explanation of what is going on here. I assumed max_batch_size was just splitting the input into smaller batches under the hood, so the end result would be the same?

yup, max batch size just makes sure it never processes more than that size at any time, preventing the memory from spilling over the limit!
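
put differently, the optimizer step still covers the full batch_size of 32; max_batch_size only caps how many samples go through the model at once. a rough sketch of the accumulation (illustrative only, not the trainer's exact internals):

```python
import torch

model = torch.nn.Linear(16, 1)
optimizer = torch.optim.Adam(model.parameters(), lr = 1e-4)

batch = torch.randn(32, 16)                  # one "logical" batch of 32 samples
for micro in batch.split(8):                 # processed 8 at a time to cap peak memory
    loss = model(micro).pow(2).mean() / 4    # 32 / 8 = 4 micro-batches, so scale each loss
    loss.backward()                          # gradients accumulate across the micro-batches

optimizer.step()                             # a single weight update covering all 32 samples
optimizer.zero_grad()
```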

@jacobwjs thank you Jacob! 🙏

@lucidrains ok, simple enough. i’m maxed out on compute until tomorrow or the following day, but will kick off something when a gpu frees up if vedantroy doesn’t beat me to it 😃

very cool! can you think of a metric that, compared with and without it, would give a good idea of whether it’s beneficial? basically, how do I test this and know if it has a positive impact?

@vedantroy Hi Vedant again 👋

Could you keep your current settings, and try turning bdb6ee0#diff-edef3c5fe92797a22c0b8fc6cca1d57b4b84ef03dfdfe802ed9147e21fc88109R1036 on?

@lucidrains is this meant to simply scale attention?
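
My understanding of the usual cosine-sim attention formulation (a generic sketch, not necessarily the exact code in the linked commit): q and k are l2-normalized before the dot product, and the 1/sqrt(d) factor is replaced by a fixed (or learned) scale, so the logits stay bounded and the fp16 softmax is much less likely to overflow.

```python
import torch
import torch.nn.functional as F

def cosine_sim_attention(q, k, v, scale = 16):
    # generic sketch: normalize q and k so their dot products are cosine similarities
    # in [-1, 1], then apply a fixed scale instead of 1 / sqrt(dim_head)
    q, k = map(lambda t: F.normalize(t, dim = -1), (q, k))
    sim = torch.einsum('b h i d, b h j d -> b h i j', q, k) * scale
    attn = sim.softmax(dim = -1)
    return torch.einsum('b h i j, b h j d -> b h i d', attn, v)

q = k = v = torch.randn(1, 8, 128, 64)    # (batch, heads, seq, dim_head)
out = cosine_sim_attention(q, k, v)       # (1, 8, 128, 64)
```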

I’m training for hundreds of thousands of steps using fp16. No NaNs to speak of. Perhaps look at the gradients, and if there’s an issue, clip them: https://github.com/lucidrains/imagen-pytorch/blob/e1a8dc258246bae465d489754afacea8dd9bbf43/imagen_pytorch/trainer.py#L198
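
e.g. something along these lines (assuming clipping is exposed as a `max_grad_norm` kwarg at the linked line; otherwise `torch.nn.utils.clip_grad_norm_` on the unet’s parameters does the same job):

```python
# assumes the trainer exposes clipping as a `max_grad_norm` kwarg (see the linked trainer.py line)
trainer = ImagenTrainer(
    imagen,
    fp16 = True,
    max_grad_norm = 0.5   # clip the gradient norm so one bad batch can't blow up the update
)
```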