diffusers: Training on zero terminal SNR seems to be broken
Describe the bug
Over the last week I’ve been running experiments with the Diffusers pix2pix script, adapted for training the SDXL U-net with a focus on text2img.
I’ve observed a persistent contrast issue with SDXL on v_prediction.
Prompt: "white woman in white pants and white shirt standing in front of white background"
(The comparison images are omitted here; their captions were:)
- Versus SD 2.0-v (not 2.1)
- After fine-tuning
- SD 2.0-v trained on portrait aspect
- Versus SD 2.1 (before terminal SNR fine-tuning)
- During terminal SNR fine-tuning
- After fine-tuning (sorry for the different resolutions; that doesn’t impact results)
Darkness prompt:
- SD 2.0-v during fine-tuning: mild contrast issue
- SD 2.1-v after fine-tuning: perfect contrast
- SDXL 1.0
Reproduction
I’ve tried pretty much every tunable flag and code modification that made sense, and even some that didn’t.
I’ve connected with other trainers who are familiar with terminal SNR training; they are just as lost and confused as I am.
This occurs:
- When training SDXL with `prediction_type='v_prediction'`. I have not tried `sample` (x-prediction).
- With or without EMA
- With any value of `guidance_rescale` from 0.0 to 1.0 (see the inference-side sketch further below)
- With any value of guidance scale, from 1.0 to 20.0
- Over many training days (1 week) or a quick burn with a high learning rate
- LR scheduling has zero effect on this problem
- With low or high batch size, and zero or more gradient accumulation steps
- Min-SNR weighting doesn’t have any impact on contrast, though it encouraged convergence
- Offset noise causes more stability issues. It’s possibly adding noise on a zero-sigma timestep?
- Input perturbations don’t seem to help at all
- When using `timestep_spacing='leading'` or `timestep_spacing='trailing'`. I have not tried `linspace`.
- When using DDPMScheduler for training with `trained_betas` originating from a DDIMScheduler with its rescale parameter set (see the sketch after this list)
- At any training resolution
- When using AdamW8Bit or D-Adaptation (external dependency)
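For reference, here is a minimal sketch of the scheduler setup described in the list above, using the public diffusers API. The SDXL repo id and the final sanity check are illustrative assumptions, not copied from the actual training script: a DDIMScheduler with its zero-terminal-SNR rescale applied, whose betas are then handed to a DDPMScheduler as `trained_betas` for v-prediction training.

```python
# Sketch only: the zero-terminal-SNR scheduler configuration described above,
# not the exact training script. Model id and the print are illustrative.
from diffusers import DDIMScheduler, DDPMScheduler

# rescale_betas_zero_snr=True rescales the beta schedule so the final cumulative
# alpha is zero, i.e. the terminal timestep is pure noise.
ddim = DDIMScheduler.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",  # assumed base checkpoint
    subfolder="scheduler",
    rescale_betas_zero_snr=True,
    timestep_spacing="trailing",
    prediction_type="v_prediction",
)

# Training-side noise scheduler: reuse the rescaled betas via trained_betas.
noise_scheduler = DDPMScheduler(
    trained_betas=ddim.betas.numpy().tolist(),
    prediction_type="v_prediction",
)

# Sanity check: with zero terminal SNR, alpha_bar at the last timestep should be ~0.
print(noise_scheduler.alphas_cumprod[-1])
```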
I have not been able to try disabling gradient checkpointing, which some communities theorise has an impact on v_prediction adaptation; if it did, that would be a bug. I haven’t been able to confirm this.
I have heard anecdotal evidence that other trainers (e.g. Kohya) have the same issue with SDXL v-prediction.
It’s possible that this will simply never work, but that doesn’t feel correct.
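For completeness, here is a sketch of the inference-side settings referenced in the list above (the guidance scale and `guidance_rescale` sweeps). The checkpoint path, prompt, and parameter values are placeholders, not the exact ones from these runs:

```python
# Sketch of the inference-side settings swept above; checkpoint path, prompt and
# parameter values are placeholders, not the exact ones used in these experiments.
import torch
from diffusers import DDIMScheduler, StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "/path/to/finetuned-sdxl-v",  # placeholder for the fine-tuned checkpoint
    torch_dtype=torch.float16,
).to("cuda")

# Match the training-time schedule: v-prediction, trailing spacing, zero terminal SNR.
pipe.scheduler = DDIMScheduler.from_config(
    pipe.scheduler.config,
    prediction_type="v_prediction",
    timestep_spacing="trailing",
    rescale_betas_zero_snr=True,
)

image = pipe(
    "white woman in white pants and white shirt standing in front of white background",
    num_inference_steps=25,
    guidance_scale=7.5,    # swept 1.0 - 20.0 above
    guidance_rescale=0.7,  # swept 0.0 - 1.0 above
).images[0]
```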
Logs
No response
System Info
- `diffusers` version: 0.20.0.dev0
- Platform: Linux-5.19.0-45-generic-x86_64-with-glibc2.31
- Python version: 3.9.16
- PyTorch version (GPU?): 2.1.0.dev20230730+cu118 (True)
- Huggingface_hub version: 0.16.4
- Transformers version: 4.30.2
- Accelerate version: 0.18.0
- xFormers version: 0.0.21+82254f4.d20230731
- Using GPU in script?: A100-80G
- Using distributed or parallel set-up in script?: No
Who can help?
About this issue
- Original URL
- State: closed
- Created 10 months ago
- Reactions: 1
- Comments: 28 (28 by maintainers)
I actually totally rewrote the bucketing logic this weekend, so that others such as yourself can read it more easily, with less time invested in following the breadcrumbs.
The current plan is to finish up a proof-of-concept checkpoint here, then make a few follow-up fixes to the dataset sampler code before settling in for a longer training run with a greater quantity of images.
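To illustrate what I mean by bucketing (this is not the rewritten code itself, just a minimal sketch of the idea, with an assumed bucket list): each image is assigned to the aspect-ratio bucket closest to its own shape, so that any given batch only ever contains images of one resolution.

```python
# Minimal illustration of aspect-ratio bucketing, not the actual rewritten sampler code.
from collections import defaultdict

# Assumed bucket list: (width, height) pairs around ~1 megapixel, SDXL-style.
BUCKETS = [(1024, 1024), (896, 1152), (1152, 896), (832, 1216), (1216, 832)]

def nearest_bucket(width: int, height: int) -> tuple[int, int]:
    """Pick the bucket whose aspect ratio is closest to the image's aspect ratio."""
    aspect = width / height
    return min(BUCKETS, key=lambda b: abs(b[0] / b[1] - aspect))

def group_by_bucket(images: list[tuple[str, int, int]]) -> dict[tuple[int, int], list[str]]:
    """Group image paths (path, width, height) by bucket; batches are drawn per group."""
    groups = defaultdict(list)
    for path, w, h in images:
        groups[nearest_bucket(w, h)].append(path)
    return groups
```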
SDXL trains impressively fast. EMA weights and min-SNR are a godsend; you don’t get v-prediction adaptation this quickly with AdamW8Bit without min-SNR.
375 steps into training (sample image omitted).
Our WandB logs (ignore the incorrect tag in the project name; this uses AdamW8Bit, not Dadapt): https://wandb.ai/bghira/sdxl-v-cocoPhoto87k-minSNR-dadapt-norm
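For anyone curious what the min-SNR weighting mentioned above refers to, here is a rough sketch of the per-timestep loss weight for v-prediction, following the Min-SNR-gamma formulation; `gamma=5.0` and the surrounding variable names are assumptions, not necessarily what this run used:

```python
# Rough sketch of min-SNR-gamma loss weighting for v-prediction; gamma=5.0 and the
# variable names are assumptions, not the exact training-script code.
import torch

def min_snr_v_weights(alphas_cumprod: torch.Tensor, timesteps: torch.Tensor,
                      gamma: float = 5.0) -> torch.Tensor:
    alpha_bar = alphas_cumprod[timesteps]
    snr = alpha_bar / (1.0 - alpha_bar)  # SNR(t) = alpha_bar / (1 - alpha_bar)
    # For v-prediction the usual formulation divides by (SNR + 1) instead of SNR,
    # which also stays finite at a zero-SNR terminal timestep.
    return torch.minimum(snr, torch.full_like(snr, gamma)) / (snr + 1.0)

# usage: weights = min_snr_v_weights(noise_scheduler.alphas_cumprod, timesteps)
#        loss = (weights * per_sample_mse).mean()
```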
Evidently, bypassing the image cropping transforms meant I lost the normalization too. This one is 100% on me; I caused this problem. Thank you all for the help.
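For reference, since this was the culprit: the normalization in question is the standard mapping to [-1, 1] that the diffusers text-to-image examples apply after cropping. A sketch of what got dropped along with the crop (the resolution value is an assumption):

```python
# Sketch of the preprocessing whose normalization step was accidentally dropped along
# with the cropping; mirrors the pattern in the diffusers text-to-image examples.
from torchvision import transforms

resolution = 1024  # assumed SDXL training resolution

train_transforms = transforms.Compose([
    transforms.Resize(resolution, interpolation=transforms.InterpolationMode.BILINEAR),
    transforms.CenterCrop(resolution),
    transforms.ToTensor(),               # PIL [0, 255] -> float [0, 1]
    transforms.Normalize([0.5], [0.5]),  # [0, 1] -> [-1, 1], as expected by the VAE
])
```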