diffusers: [examples] Error due to a dtype mismatch in the UNet while running `train_text_to_image.py`
Describe the bug
I am referring to this script: https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py.
To perform qualitative validation, and following https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image_lora.py, I have added some sample inference code:
```python
if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
    logger.info(
        f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
        f" {args.validation_prompt}."
    )
    # create pipeline
    temp_unet = accelerator.unwrap_model(unet)
    pipeline = StableDiffusionPipeline.from_pretrained(
        args.pretrained_model_name_or_path,
        unet=temp_unet,
        revision=args.revision,
        torch_dtype=weight_dtype,
    )
    pipeline = pipeline.to(accelerator.device)
    pipeline.set_progress_bar_config(disable=True)

    # run inference
    generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
    images = []
    for _ in range(args.num_validation_images):
        images.append(
            pipeline(
                args.validation_prompt,
                num_inference_steps=30,
                generator=generator,
            ).images[0]
        )
```
Training runs smoothly, but when it reaches the pipeline call it fails with an error complaining about a mismatch between the dtype of the model (UNet) and the dtype of its inputs.
Cc: @kashif
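For context, here is a minimal, generic PyTorch illustration of this failure mode (hypothetical, not taken from the training script; assumes a CUDA GPU): an fp32 module rejects fp16 inputs unless the call runs under autocast.

```python
import torch

# Hypothetical illustration of the dtype mismatch, not code from train_text_to_image.py.
layer = torch.nn.Conv2d(4, 4, kernel_size=3).cuda()                  # weights stay fp32
latents = torch.randn(1, 4, 64, 64, device="cuda", dtype=torch.float16)

try:
    layer(latents)  # RuntimeError: input and weight dtypes do not match
except RuntimeError as err:
    print(err)

# The same call succeeds under autocast, which casts per-op as needed.
with torch.autocast("cuda"):
    out = layer(latents)
print(out.dtype)  # torch.float16
```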
Reproduction
Code: https://gist.github.com/sayakpaul/97ec0007ec423960912ca8822e4cb7be
Command:
```bash
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
export dataset_name="lambdalabs/pokemon-blip-captions"

accelerate launch --mixed_precision="fp16" train_text_to_image.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --dataset_name=$dataset_name \
  --use_ema \
  --resolution=512 --center_crop --random_flip \
  --train_batch_size=1 \
  --gradient_accumulation_steps=4 \
  --gradient_checkpointing \
  --max_train_steps=250 \
  --learning_rate=1e-05 \
  --max_grad_norm=1 \
  --lr_scheduler="constant" --lr_warmup_steps=0 \
  --validation_prompt="cute dragon creature" \
  --seed=666 \
  --output_dir="sd-pokemon-model"
```
Logs
No response
System Info
- `diffusers` version: 0.13.0.dev0
- Platform: Linux-4.19.0-23-cloud-amd64-x86_64-with-glibc2.10
- Python version: 3.8.15
- PyTorch version (GPU?): 1.13.1+cu116 (True)
- Huggingface_hub version: 0.12.0
- Transformers version: 4.26.0
- Accelerate version: 0.15.0
- xFormers version: not installed
- Using GPU in script?: yes
- Using distributed or parallel set-up in script?: no
I looked into it a bit. This happens because when using `mixed_precision`, the `vae` and `text_encoder` weights are cast to `fp16`, while the `unet` is still in `fp32` (but wrapped under an `accelerate` wrapper which uses `autocast` during the forward pass and then casts the output back to `fp32`; classic mixed-precision training). There are two cases here:

- If we pass the unwrapped `unet`, it's in `fp32` without the `accelerate` `autocast` wrapper, so it can't accept `fp16` input. Hence the mismatch.
- If we pass the `unet` without unwrapping, the `unet` will run fine (because of the wrapper), but the output will be cast back to `fp32` and then the `vae` decode will fail.

I think the safest way here would be to use `torch.autocast` for inference when doing mixed-precision training. cc @patrickvonplaten @pcuenca
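For illustration, a minimal sketch of what running the validation loop under `torch.autocast` could look like; it reuses the names from the snippet in the issue body (`pipeline`, `args`, `accelerator`, `generator`) and is only an assumption about the fix, not the exact patch that was merged:

```python
# Sketch only: run validation inference under autocast so the unwrapped fp32 UNet
# and the fp16 vae/text_encoder can be mixed per-op, as in mixed-precision training.
# Assumes the surrounding script context (torch already imported; pipeline, args,
# accelerator, and generator defined as in the snippet above).
images = []
for _ in range(args.num_validation_images):
    with torch.autocast(accelerator.device.type):
        image = pipeline(
            args.validation_prompt,
            num_inference_steps=30,
            generator=generator,
        ).images[0]
    images.append(image)
```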
Let's use `torch.autocast` to solve the issue, no? As explained here: https://github.com/huggingface/diffusers/issues/2568

Let me do that 😃
I just tried this with the train_text_to_image_lora.py script and couldn't reproduce it; maybe check whether the modifications you made to the text_to_image script are similar.