diffusers: Dreambooth broken, possibly because of ADAM optimizer, possibly more.
I think Hugging Face's Dreambooth is the only popular SD implementation that also uses prior-preservation loss (PPL), so I've been motivated to get it working, but the results have been terrible and the entire model degrades, regardless of: # timesteps, learning rate, PPL turned on/off, # instance samples, # class regularization samples, etc. I've read the paper and found that the authors actually unfreeze everything, including the text embedder (and the VAE? I'm not sure, so I leave it frozen), so I implemented textual inversion within the dreambooth example (a new token, unfreezing a single row of the embedder), which improves results considerably, but the whole model still degrades no matter what.
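For reference, a minimal sketch of that embedding tweak, assuming a diffusers-style setup; the placeholder token `<sks>`, the checkpoint name, and the `restore_frozen_rows` helper are illustrative, not code from the example script:

```python
import torch
from transformers import CLIPTextModel, CLIPTokenizer

pretrained = "CompVis/stable-diffusion-v1-4"  # assumed base checkpoint
tokenizer = CLIPTokenizer.from_pretrained(pretrained, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(pretrained, subfolder="text_encoder")

# Add a brand-new placeholder token and grow the embedding matrix by one row.
placeholder_token = "<sks>"
tokenizer.add_tokens(placeholder_token)
text_encoder.resize_token_embeddings(len(tokenizer))
new_token_id = tokenizer.convert_tokens_to_ids(placeholder_token)

# Freeze the text encoder, then re-enable grads only on the embedding matrix.
text_encoder.requires_grad_(False)
embeddings = text_encoder.get_input_embeddings()
embeddings.weight.requires_grad_(True)

# Snapshot the original embeddings so every row except the new token's can be
# restored after each optimizer step, i.e. only one row is effectively trained.
orig_embeds = embeddings.weight.data.clone()

def restore_frozen_rows():
    mask = torch.ones(embeddings.weight.shape[0], dtype=torch.bool)
    mask[new_token_id] = False
    with torch.no_grad():
        embeddings.weight.data[mask] = orig_embeds[mask]
# call restore_frozen_rows() after each optimizer.step()
```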
Someone smarter than me can confirm, but I think the culprit is ADAM:
My hypothesis is that since Adam's (default) weight decay drags all of the weights of the UNet etc. toward 0, it ruins the parts of the model that aren't concurrently being trained during the finetuning.
I've tested with `weight_decay` set to 0, and results seem considerably better, but I think the entire model is still degrading. I'm trying SGD next, so fingers crossed, but there may still be some dragon lurking in the depths even after removing Adam.
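For reference, the optimizer swap being tested looks roughly like this (a sketch, not the example script; the checkpoint name and learning rate are placeholders):

```python
import torch
from diffusers import UNet2DConditionModel

pretrained = "CompVis/stable-diffusion-v1-4"  # assumed base checkpoint
unet = UNet2DConditionModel.from_pretrained(pretrained, subfolder="unet")
learning_rate = 5e-6  # placeholder value

use_sgd = True  # flip to compare against AdamW
if use_sgd:
    optimizer = torch.optim.SGD(unet.parameters(), lr=learning_rate, momentum=0.9)
else:
    # torch.optim.AdamW defaults to weight_decay=1e-2; zero it out to test the
    # "weight decay drags untrained weights toward 0" hypothesis.
    optimizer = torch.optim.AdamW(unet.parameters(), lr=learning_rate, weight_decay=0.0)
```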
A point of reference on this journey is the JoePenna "Dreambooth" library, which doesn't implement PPL and yet preserves priors much, much better than this example; not to mention it learns the instance better, is far more editable, and preserves out-of-class concepts rather well. I expect more from this Hugging Face dreambooth example, and I'm trying to find out why it's not delivering.
Any thoughts or guidance?
EDIT1A: SGD didn't learn the instance at 1000 steps + lr=5e-5, but it definitely preserved the priors way better (upon visual inspection; the loss really doesn't decrease much in any of my inversion/dreambooth experiments).
EDIT1B: Another test failed to learn using SGD at 1500 steps + lr=1e-3 + momentum=0.9. It might be trying to learn, but not much. Priors were still nicely preserved, though.
EDIT1C: 1500 steps + lr=5e2 learned great, was editable, and didn't destroy other priors!!!
EDIT2: JoePenna seems to use AdamW, so I’m not sure what’s up anymore, but I’m still getting quite poor results training with this library’s (huggingface’s) DB example.
About this issue
- State: closed
- Created 2 years ago
- Reactions: 6
- Comments: 41 (11 by maintainers)
Hi everyone! Sorry to be so late here.
We ran a lot of experiments with the script to see if there are any issues or if it's broken. It turns out we need to carefully pick hyperparameters like the LR and number of training steps to get good results with dreambooth.
Also, the main reason the results with this script were not as good as the CompVis forks is that the `text_encoder` is being trained in those forks, and it makes a big difference in quality, especially on faces.

We compiled all our experiments in this report, and also added an option to train the `text_encoder` in the script, which can be enabled by passing the `--train_text_encoder` argument.

Note that if we train the `text_encoder`, the training won't fit on a 16GB GPU anymore; it will need at least 24GB of VRAM. It should still be possible to do it on 16GB using DeepSpeed, but it will be slower.

Please take a look at the report, hope you find it useful.
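For anyone following along, training the text encoder essentially means handing its parameters to the optimizer alongside the UNet's. A rough sketch of the idea (not the script's exact code; the checkpoint name and learning rate are placeholders):

```python
import itertools
import torch
from diffusers import UNet2DConditionModel
from transformers import CLIPTextModel

pretrained = "CompVis/stable-diffusion-v1-4"  # assumed base checkpoint
unet = UNet2DConditionModel.from_pretrained(pretrained, subfolder="unet")
text_encoder = CLIPTextModel.from_pretrained(pretrained, subfolder="text_encoder")

train_text_encoder = True  # analogous to passing --train_text_encoder
params_to_optimize = (
    itertools.chain(unet.parameters(), text_encoder.parameters())
    if train_text_encoder
    else unet.parameters()
)
optimizer = torch.optim.AdamW(params_to_optimize, lr=5e-6)  # placeholder LR
```

The extra gradients and Adam state for the text encoder are roughly why the memory requirement jumps past 16GB.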
Think @patil-suraj will very soon release a nice update of dreambooth 😃
@jslegers Have you by chance tried the JoePenna repo? I'm still trying to pin down the difference, but I think it works better, and I don't know why.
They're starting to look identical, so I don't know where the difference in perceptual quality lies. The code is semantically identical, I think; JoePenna's lib:
- `decay_weight`: leaving it to PyTorch's `1e-2` default
- `num_train_timesteps` in DDPM? That might be different between the 2 libs? (diffusers=1000 vs joepenna=1)

EDIT: we must get this library to work; the Shivam fork of it can happily train 1024x1024 at 3 batches at a time on a 24GB card.
EDIT2: Could that DDPM `num_timesteps` be the issue? The paper's Fig 7 also elucidates timesteps a bit.
So if I understand the paper, and these 2 repos, HF-diffusers runs 1000 steps of noise, and says “try and denoise that, bwahaha!”, whereas JoePenna runs 1 step of noise saying “let’s play on easy mode”.
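To make that concrete, here's a sketch of the noising step as I understand it (the scheduler hyperparameters below are the usual SD values, assumed rather than copied from either repo): each sample gets a random timestep in `[0, num_train_timesteps)`, so setting `num_train_timesteps=1` collapses training to the easiest denoising problem.

```python
import torch
from diffusers import DDPMScheduler

# Usual Stable Diffusion betas; num_train_timesteps is the knob in question.
noise_scheduler = DDPMScheduler(
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    num_train_timesteps=1000,  # joepenna-style "easy mode" would be 1
)

latents = torch.randn(4, 4, 64, 64)  # stand-in for VAE-encoded training latents
noise = torch.randn_like(latents)

# One random timestep per sample: large t means heavily noised latents.
timesteps = torch.randint(
    0, noise_scheduler.config.num_train_timesteps, (latents.shape[0],)
).long()
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
```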
EDIT3: I tried a short 10-minute round of training with `num_timesteps=1` instead of `1000`, and the loss is a helluva lot smoother. Before it was perfect zigzags; now, after an initial transient, it's perfectly monotonically decreasing. That's a good sign.

@patil-suraj I know you're slammed (I see you on every support ticket on this repo!), but any chance you've had an opportunity to ask the Dreambooth authors about the catastrophic forgetting + overfitting we've been facing? My pet theory is that it has to do with weight decay, but whatever it is, I'm just curious! No rush of course, and thanks for all your hard work!
Been testing using my face all day with a variety of optimisers because reasons 😃 Things learnt so far -
Default Johnny Depp (seed 1201043):
With DB model (defaults), 200 preservation images, fp16, 1000 steps:
I can really see my features in there!
With DB model, 1200 preservation images, fp16, 1600 steps, 4e-6:
Not so much me 😃
Thanks a lot @affableroots and sorry for being late here. Running some experiments today, also comparing against other codebases. Will post my findings here soon.
I did some testing regarding the impact of Dreambooth on different prompts, using the same seed.
Pretty much all of my tests produced results similar to this when running Dreambooth with class “man” and concept “johnslegers”:
I've tried using different configs, but to no avail. The degradation persists no matter how many input pics I use, how many class pics I use, what value I use for prior preservation, etc.