diffusers: Training on zero terminal SNR seems to be broken

Describe the bug

Over the last week I’ve been running experiments with the Diffusers pix2pix script, adapted for training the SDXL U-net with a focus on text2img.

I’ve observed a continual contrast issue with SDXL on v_prediction.

Prompt: white woman in white pants and white shirt standing in front of white background

[image]

Versus SD 2.0-v (not 2.1)

After fine-tuning: [image]

SD 2.0-v trained on portrait aspect: [image]

Versus SD 2.1

(before terminal SNR fine-tuning): [image]

During terminal SNR fine-tuning: [image]

After fine-tuning (sorry for the different resolutions; that doesn’t impact the results): [image]

Darkness prompt

SD 2.0-v during fine-tuning, mild contrast issue: [image]

SD 2.1-v after fine-tuning, perfect contrast: [image]

SDXL 1.0: [image]

Reproduction

I’ve tried pretty much every tunable flag and code modification that made sense, and even some that didn’t.

I’ve compared notes with other trainers who are familiar with terminal SNR training, and they are just as lost and confused as I am.

This occurs:

  • When training SDXL on prediction_type='v_prediction'. I have not tried prediction_type='sample' (x-prediction).
  • With or without EMA
  • With any value of guidance_rescale from 0.0 to 1.0 (see the inference sketch further down)
  • With any value of guidance scale, from 1.0 to 20.0
  • Over many training days (a week) or in a quick burn with a high learning rate
  • LR scheduling has zero effect on this problem
  • With low or high batch size and zero or more gradient accumulation steps
  • Min SNR weighting doesn’t have any impact on contrast, though it encouraged convergence
  • Offset noise causes more stability issues. It’s possibly adding noise at a zero-sigma timestep?
  • Input perturbations don’t seem to help at all.
  • When using timestep_spacing='leading' or timestep_spacing='trailing'. I have not tried linspace.
  • When using DDPMScheduler for training with trained_betas taken from a DDIMScheduler that has the zero-SNR rescale parameter set (see the sketch just after this list).
  • At any training resolution.
  • When using AdamW8Bit or D-Adaptation (external dependency)
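
For reference, a minimal sketch of the scheduler setup described in the trained_betas bullet above, assuming the `rescale_betas_zero_snr` flag on DDIMScheduler; the model path is illustrative:

```python
from diffusers import DDIMScheduler, DDPMScheduler

# DDIM scheduler with the zero terminal SNR beta rescale applied
# (model path is illustrative; the diffusers flag is rescale_betas_zero_snr).
ddim = DDIMScheduler.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    subfolder="scheduler",
    rescale_betas_zero_snr=True,
    timestep_spacing="trailing",
    prediction_type="v_prediction",
)

# Hand the rescaled betas to the DDPM training scheduler as trained_betas,
# so the noise added during training follows the zero-SNR schedule.
noise_scheduler = DDPMScheduler.from_config(
    ddim.config,
    trained_betas=ddim.betas.numpy().tolist(),
    prediction_type="v_prediction",
)
```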

I have not been able to try disabling gradient checkpointing, which some communities theorise has an impact on v_prediction adaptation; if it did, that would itself be a bug. I haven’t been able to confirm this.

I have heard anecdotal evidence that other trainers (e.g. Kohya) have the same issue with SDXL v-prediction.

It’s possible that this will simply never work, but that doesn’t feel correct.
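
For completeness, a hedged sketch of the inference side that the guidance_rescale bullet refers to; the checkpoint path is illustrative, and 0.7 is just the value suggested in the zero terminal SNR paper:

```python
import torch
from diffusers import DDIMScheduler, StableDiffusionXLPipeline

# Load the fine-tuned checkpoint (path is illustrative) and swap in a DDIM
# scheduler configured for zero terminal SNR sampling.
pipe = StableDiffusionXLPipeline.from_pretrained(
    "path/to/finetuned-sdxl-vpred", torch_dtype=torch.float16
).to("cuda")
pipe.scheduler = DDIMScheduler.from_config(
    pipe.scheduler.config,
    rescale_betas_zero_snr=True,
    timestep_spacing="trailing",
    prediction_type="v_prediction",
)

# guidance_rescale is the pipeline argument referenced above; ~0.7 is the
# value suggested in the zero terminal SNR paper.
image = pipe(
    "white woman in white pants and white shirt standing in front of white background",
    guidance_scale=7.5,
    guidance_rescale=0.7,
).images[0]
image.save("contrast-test.png")
```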

Logs

No response

System Info

  • diffusers version: 0.20.0.dev0
  • Platform: Linux-5.19.0-45-generic-x86_64-with-glibc2.31
  • Python version: 3.9.16
  • PyTorch version (GPU?): 2.1.0.dev20230730+cu118 (True)
  • Huggingface_hub version: 0.16.4
  • Transformers version: 4.30.2
  • Accelerate version: 0.18.0
  • xFormers version: 0.0.21+82254f4.d20230731
  • Using GPU in script?: A100-80G
  • Using distributed or parallel set-up in script?: No

Who can help?

@williamberman and @PeterL1n

About this issue

  • Original URL
  • State: closed
  • Created 10 months ago
  • Reactions: 1
  • Comments: 28 (28 by maintainers)

Most upvoted comments

I actually totally rewrote the bucketing logic this weekend, so that others such as yourself can read it more easily, with less time invested in following the breadcrumbs.

The current plan is to finish up a proof-of-concept checkpoint here, then make a few follow-up fixes to the dataset sampler code before settling in for a longer training run with a greater quantity of images.

SDXL trains impressively fast. EMA weights and min-SNR are a godsend. You don’t get v-prediction adaptation this quickly with AdamW8Bit without min-SNR.
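
For context, the min-SNR weighting mentioned here refers to loss reweighting along these lines; this is a rough sketch for a v-prediction target, not the exact trainer code, and snr_gamma=5.0 is only an illustrative default:

```python
import torch
import torch.nn.functional as F

def min_snr_v_loss(model_pred, target, timesteps, noise_scheduler, snr_gamma=5.0):
    """Min-SNR-weighted MSE for a v-prediction target (a sketch, not the exact trainer code)."""
    alphas_cumprod = noise_scheduler.alphas_cumprod.to(timesteps.device)[timesteps]
    snr = alphas_cumprod / (1.0 - alphas_cumprod)

    # min(SNR, gamma) / (SNR + 1) is the usual form for v-prediction; the +1
    # also keeps the weight finite at the zero-SNR terminal timestep.
    weights = torch.clamp(snr, max=snr_gamma) / (snr + 1.0)

    loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
    loss = loss.mean(dim=list(range(1, loss.ndim))) * weights
    return loss.mean()
```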

375 steps into training:

image

Our WandB logs (ignore the incorrect tag in the project name; this run uses AdamW8Bit, not D-Adaptation): https://wandb.ai/bghira/sdxl-v-cocoPhoto87k-minSNR-dadapt-norm

image

Evidently, bypassing the image cropping transforms meant I lost the normalization too. This one is 100% on me; I caused this problem. Thank you all for the help.
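
For anyone hitting the same thing, the missing piece looked roughly like this; the resolution and interpolation choices are illustrative, and the key part is the final Normalize that maps pixels to [-1, 1] before they reach the VAE:

```python
from torchvision import transforms

# The VAE expects pixel values in [-1, 1]; dropping this final Normalize while
# customizing the crop/resize path was the root cause described above.
# (Resolution and interpolation choices here are illustrative.)
train_transforms = transforms.Compose(
    [
        transforms.Resize(1024, interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.CenterCrop(1024),
        transforms.ToTensor(),               # scales to [0, 1]
        transforms.Normalize([0.5], [0.5]),  # maps to [-1, 1]
    ]
)
```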