diffusers: SDXL DreamBooth can't be resumed from a checkpoint with fp16 training

Describe the bug

train_dreambooth_lora_sdxl.py can’t be resumed from a checkpoint when training with fp16. The logged error is "Attempting to unscale FP16 gradients."

This is a big blocker for training on the free Colab tier: fp16 is needed to fit in VRAM, but resuming from checkpoints is also needed because the session can hit a timeout at any moment.

Reproduction

Reproduce with: https://colab.research.google.com/drive/15woNcXcpsa3GDGk6cmDtIL2V8zRtOOj3

Logs

No response

System Info

Latest diffusers; the system is whatever is on Colab (see the linked Colab above).

Who can help?

@patrickvonplaten @sayakpaul

About this issue

  • State: closed
  • Created 10 months ago
  • Reactions: 2
  • Comments: 39 (21 by maintainers)

Most upvoted comments

This has been solved with https://github.com/huggingface/diffusers/pull/6514. So, I am actually going to close this.

@epi-morphism

Okay, training is actually not broken. It turned out to be a learning-rate (LR) thing.

Colab: https://colab.research.google.com/gist/sayakpaul/13864eb0427bef50f5e95f08b60a03a3/scratchpad.ipynb

WandB: https://wandb.ai/xin-sayak/dreambooth-lora-sd-xl/runs/qb1rc2pn?workspace=user-sayakpaul

I am training on PT 2.0 as well to see if it’s working as expected. Will update here.

But currently, I don’t have any handle on why resuming training from an FP16 checkpoint doesn’t work as expected. Maybe resuming with the exact same parameters that were used initially could help avoid the bug? I will run this myself and see.

https://gist.github.com/sayakpaul/0b5c13f2020ac5ee7a3547a0f8ccd781 – you can check that even non-resumed runs lead to broken LoRAs. Corresponding run page: https://wandb.ai/xin-sayak/dreambooth-lora-sd-xl/runs/axim7chx.

I was able to fix the prior version of the file, but the new commit of the example with the new LoRA layers raises another issue when using mixed_precision + resuming from a checkpoint. I will investigate more later.

  File "/home/marc/anaconda3/envs/accelerate/lib/python3.11/site-packages/torch/cuda/amp/grad_scaler.py", line 372, in step
    assert len(optimizer_state["found_inf_per_device"]) > 0, "No inf checks were recorded for this optimizer."
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: No inf checks were recorded for this optimizer.

After investigation, it seems like it is an issue on the diffusers side. When we resume from the checkpoint, we load back the unet LoRA weights. This successively calls the functions load_model_hook, load_lora_into_unet and load_attn_procs. In load_attn_procs, the entire unet, including the LoRA weights, is converted to the dtype of the unet. The problem is that in the script, we cast all non-trainable weights to fp16 or bf16 for the mixed-precision training setup (vae, non-LoRA text_encoder and non-LoRA unet).

# For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
# as these weights are only used for inference, keeping weights in full precision is not required.
weight_dtype = torch.float32
if accelerator.mixed_precision == "fp16":
    weight_dtype = torch.float16
elif accelerator.mixed_precision == "bf16":
    weight_dtype = torch.bfloat16

# Move unet, vae and text_encoder to device and cast to weight_dtype
unet.to(accelerator.device, dtype=weight_dtype)

So in the end, the LoRA layers also end up being converted to half precision in load_attn_procs. See the code below or the linked source.

# set lora layers
for target_module, lora_layer in lora_layers_list:
    target_module.set_lora_layer(lora_layer)
self.to(dtype=self.dtype, device=self.device)
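
As a quick sanity check (a hedged debugging sketch, assuming the script's unet variable), the downcast can be confirmed by listing any trainable parameters that are no longer fp32 right after resuming:

# Any trainable (LoRA) parameter printed here was downcast by load_attn_procs;
# with the bug present, these show up as torch.float16 / torch.bfloat16.
for name, param in unet.named_parameters():
    if param.requires_grad and param.dtype != torch.float32:
        print(f"{name}: {param.dtype}")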

To solve this issue, we can either modify load_attn_procs to set the lora_layers after this line self.to(dtype=self.dtype, device=self.device), or convert the lora_layers back to fp32 in the script. LMK which option you prefer @patrickvonplaten, I will do a PR to fix that.
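
A minimal sketch of the second option, assuming the script's accelerator and unet variables (extend the list with the text encoders if they are trained), could look like this:

# Hedged sketch of option 2: after loading the checkpoint, cast only the
# trainable (LoRA) parameters back to fp32 so the GradScaler never sees
# fp16 gradients; the non-trainable weights stay in weight_dtype.
if accelerator.mixed_precision == "fp16":
    for model in [unet]:
        for param in model.parameters():
            if param.requires_grad:
                param.data = param.data.to(torch.float32)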

Hi @danieltanhx, indeed there is an issue with the optimizer where the gradients somehow end up in dtype torch.float16 instead of torch.float32 after resuming training. We will work on fixing this ASAP. cc @muellerzr
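
For reference, the ValueError from the original report comes from GradScaler's unscaling step, which refuses to unscale gradients stored in fp16; paraphrased from torch/cuda/amp/grad_scaler.py, so treat the exact lines as approximate:

# Inside GradScaler._unscale_grads_ (paraphrased):
if (not allow_fp16) and param.grad.dtype == torch.float16:
    raise ValueError("Attempting to unscale FP16 gradients.")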

A temporary solution is to brute-force edit /usr/local/lib/python3.10/dist-packages/torch/cuda/amp/grad_scaler.py and change line 284 from

optimizer_state["found_inf_per_device"] = self._unscale_grads_(optimizer, inv_scale, found_inf, False)

to

optimizer_state["found_inf_per_device"] = self._unscale_grads_(optimizer, inv_scale, found_inf, True)
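
If editing the installed file is not an option (e.g. on Colab), a monkey-patch sketch with the same effect is below; it assumes the scaler in use is torch.cuda.amp.GradScaler and that the private _unscale_grads_ method keeps its current signature, and it only silences the check rather than fixing the underlying dtype problem.

import functools

import torch

# Hedged workaround sketch: force allow_fp16=True without editing grad_scaler.py.
_orig_unscale_grads = torch.cuda.amp.GradScaler._unscale_grads_

@functools.wraps(_orig_unscale_grads)
def _unscale_grads_allow_fp16(self, optimizer, inv_scale, found_inf, allow_fp16):
    # Ignore the incoming allow_fp16 flag and always pass True.
    return _orig_unscale_grads(self, optimizer, inv_scale, found_inf, True)

torch.cuda.amp.GradScaler._unscale_grads_ = _unscale_grads_allow_fp16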