diffusers: [examples] Error due to a dtype mismatch in the UNet while running `train_text_to_image.py`

Describe the bug

I am referring to this script: https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py.

To perform qualitative validation, and following https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image_lora.py, I added the following sample inference code:

        if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
            logger.info(
                f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
                f" {args.validation_prompt}."
            )
            # create pipeline
            temp_unet = accelerator.unwrap_model(unet)
            pipeline = StableDiffusionPipeline.from_pretrained(
                args.pretrained_model_name_or_path,
                unet=temp_unet,
                revision=args.revision,
                torch_dtype=weight_dtype
            )
            pipeline = pipeline.to(accelerator.device)
            pipeline.set_progress_bar_config(disable=True)

            # run inference
            generator = torch.Generator(device=accelerator.device).manual_seed(
                args.seed
            )
            images = []
            for _ in range(args.num_validation_images):
                images.append(
                    pipeline(
                        args.validation_prompt,
                        num_inference_steps=30,
                        generator=generator,
                    ).images[0]
                )

Training runs smoothly, but when it reaches the pipeline call, it fails with an error complaining about a mismatch between the dtype of the model (UNet) and that of its inputs.

Cc: @kashif

Reproduction

Code: https://gist.github.com/sayakpaul/97ec0007ec423960912ca8822e4cb7be

Command:

export MODEL_NAME="CompVis/stable-diffusion-v1-4"
export dataset_name="lambdalabs/pokemon-blip-captions"

accelerate launch --mixed_precision="fp16"  train_text_to_image.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --dataset_name=$dataset_name \
  --use_ema \
  --resolution=512 --center_crop --random_flip \
  --train_batch_size=1 \
  --gradient_accumulation_steps=4 \
  --gradient_checkpointing \
  --max_train_steps=250 \
  --learning_rate=1e-05 \
  --max_grad_norm=1 \
  --lr_scheduler="constant" --lr_warmup_steps=0 \
  --validation_prompt="cute dragon creature" \
  --seed=666 \
  --output_dir="sd-pokemon-model" 

Logs

No response

System Info

- `diffusers` version: 0.13.0.dev0
- Platform: Linux-4.19.0-23-cloud-amd64-x86_64-with-glibc2.10
- Python version: 3.8.15
- PyTorch version (GPU?): 1.13.1+cu116 (True)
- Huggingface_hub version: 0.12.0
- Transformers version: 4.26.0
- Accelerate version: 0.15.0
- xFormers version: not installed
- Using GPU in script?: yes
- Using distributed or parallel set-up in script?: no

About this issue

  • State: closed
  • Created a year ago
  • Reactions: 1
  • Comments: 18 (18 by maintainers)

Most upvoted comments

I looked into it a bit. This happens because, when using mixed_precision, the vae and text_encoder weights are cast to fp16 while the unet is still in fp32 (but wrapped in an accelerate wrapper that uses autocast during the forward pass and then casts the output back to fp32, i.e. classic mixed-precision training). There are two cases here (a minimal illustration follows the list):

  • When we unwrap the unet, it’s in fp32 without the accelerate autocast wrapper, so it can’t accept fp16 input. Hence the mismatch.
  • If we use the unet without unwrapping, the unet itself runs fine (because of the wrapper), but its output is cast back to fp32, and the subsequent vae decode (in fp16) fails.
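
A minimal, self-contained illustration of the first case (hypothetical code, not the actual training script; requires a CUDA device):

    import torch

    # A plain fp32 module stands in for the unwrapped UNet; outside of
    # autocast it rejects fp16 input with a dtype-mismatch RuntimeError.
    conv = torch.nn.Conv2d(4, 4, 3).cuda()  # fp32 weights, like the unwrapped unet
    latents = torch.randn(1, 4, 8, 8, device="cuda", dtype=torch.float16)  # fp16 input

    try:
        conv(latents)
    except RuntimeError as e:
        print(e)  # input type (Half) and weight type (Float) must match

    # Under autocast the same call succeeds: the op runs in fp16.
    with torch.autocast("cuda"):
        out = conv(latents)
    print(out.dtype)  # torch.float16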

I think the safest way here would be to use torch.autocast for inference when doing mixed-precision training. cc @patrickvonplaten @pcuenca
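
Applied to the validation snippet from the report, a sketch of that fix would look like this (same variables as in the snippet above; accelerator.device.type resolves to "cuda" on GPU):

            # run inference under autocast so the fp32 unet and the fp16
            # vae/text_encoder can interoperate (sketch of the proposed fix)
            images = []
            with torch.autocast(accelerator.device.type):
                for _ in range(args.num_validation_images):
                    images.append(
                        pipeline(
                            args.validation_prompt,
                            num_inference_steps=30,
                            generator=generator,
                        ).images[0]
                    )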

Let’s use torch.autocast to solve the issue, no? As explained here: https://github.com/huggingface/diffusers/issues/2568

Let me do that 😃

I just tried this with the train_text_to_image_lora.py script and couldn’t reproduce it; maybe check whether the modifications you made to the text_to_image script are similar.