diffusers: Using --mixed_precision="fp16" raises ValueError: Query/Key/Value should all have the same dtype

Describe the bug

ValueError: Query/Key/Value should all have the same dtype (query.dtype: torch.float32, key.dtype: torch.float16, value.dtype: torch.float16)
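xFormers validates its inputs before dispatching a kernel and rejects mixed dtypes. The snippet below is a minimal sketch of an equivalent check in plain PyTorch, not xFormers' actual implementation; `check_qkv_dtypes` is a hypothetical name. It reproduces the situation in the error: a float32 query next to float16 key/value.

```python
import torch


def check_qkv_dtypes(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> None:
    """Hypothetical stand-in for the dtype check xFormers runs before attention."""
    if not (query.dtype == key.dtype == value.dtype):
        raise ValueError(
            "Query/Key/Value should all have the same dtype\n"
            f"  query.dtype: {query.dtype}\n"
            f"  key.dtype  : {key.dtype}\n"
            f"  value.dtype: {value.dtype}"
        )


# Shapes are arbitrary; only the dtypes matter for this check.
query = torch.zeros(2, 4, 8, dtype=torch.float32)  # e.g. derived from a float32 norm output
key = torch.zeros(2, 6, 8, dtype=torch.float16)    # e.g. derived from float16 text embeddings
value = torch.zeros(2, 6, 8, dtype=torch.float16)

try:
    check_qkv_dtypes(query, key, value)
except ValueError as err:
    print(err)
```

Casting the query to float16 (or all three to float32) makes the same check pass.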

Reproduction

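Run the DreamBooth LoRA SDXL script (examples/dreambooth/train_dreambooth_lora_sdxl.py) with --mixed_precision="fp16" (the run below also has xFormers memory-efficient attention enabled, as the traceback shows).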

Without --mixed_precision="fp16", it runs fine.

Logs

10/11/2023 14:43:29 - INFO - __main__ -   Num examples = 5
10/11/2023 14:43:29 - INFO - __main__ -   Num batches each epoch = 5
10/11/2023 14:43:29 - INFO - __main__ -   Num Epochs = 134
10/11/2023 14:43:29 - INFO - __main__ -   Instantaneous batch size per device = 1
10/11/2023 14:43:29 - INFO - __main__ -   Total train batch size (w. parallel, distributed & accumulation) = 2
10/11/2023 14:43:29 - INFO - __main__ -   Gradient Accumulation steps = 2
10/11/2023 14:43:29 - INFO - __main__ -   Total optimization steps = 400

Steps:   0%|          | 0/400 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/root/diffusers/examples/dreambooth/train_dreambooth_lora_sdxl.py", line 1366, in <module>
    main(args)
  File "/root/diffusers/examples/dreambooth/train_dreambooth_lora_sdxl.py", line 1106, in main
    model_pred = unet(
  File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/accelerate/utils/operations.py", line 636, in forward
    return model_forward(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/accelerate/utils/operations.py", line 624, in __call__
    return convert_to_fp32(self.model_forward(*args, **kwargs))
  File "/usr/local/lib/python3.10/site-packages/torch/amp/autocast_mode.py", line 14, in decorate_autocast
    return func(*args, **kwargs)
  File "/root/diffusers/src/diffusers/models/unet_2d_condition.py", line 1010, in forward
    sample, res_samples = downsample_block(
  File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/root/diffusers/src/diffusers/models/unet_2d_blocks.py", line 1108, in forward
    hidden_states = attn(
  File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/root/diffusers/src/diffusers/models/transformer_2d.py", line 311, in forward
    hidden_states = torch.utils.checkpoint.checkpoint(
  File "/usr/local/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 251, in checkpoint
    return _checkpoint_without_reentrant(
  File "/usr/local/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 420, in _checkpoint_without_reentrant
    output = function(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/root/diffusers/src/diffusers/models/attention.py", line 238, in forward
    attn_output = self.attn2(
  File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/root/diffusers/src/diffusers/models/attention_processor.py", line 423, in forward
    return self.processor(
  File "/root/diffusers/src/diffusers/models/attention_processor.py", line 946, in __call__
    hidden_states = xformers.ops.memory_efficient_attention(
  File "/usr/local/lib/python3.10/site-packages/xformers/ops/fmha/__init__.py", line 197, in memory_efficient_attention
    return _memory_efficient_attention(
  File "/usr/local/lib/python3.10/site-packages/xformers/ops/fmha/__init__.py", line 298, in _memory_efficient_attention
    return _fMHA.apply(
  File "/usr/local/lib/python3.10/site-packages/xformers/ops/fmha/__init__.py", line 43, in forward
    out, op_ctx = _memory_efficient_attention_forward_requires_grad(
  File "/usr/local/lib/python3.10/site-packages/xformers/ops/fmha/__init__.py", line 320, in _memory_efficient_attention_forward_requires_grad
    inp.validate_inputs()
  File "/usr/local/lib/python3.10/site-packages/xformers/ops/fmha/common.py", line 116, in validate_inputs
    raise ValueError(
ValueError: Query/Key/Value should all have the same dtype
  query.dtype: torch.float32
  key.dtype  : torch.float16
  value.dtype: torch.float16

System Info

Latest diffusers from the GitHub repo

Who can help?

No response

About this issue

  • State: closed
  • Created 9 months ago
  • Comments: 16 (8 by maintainers)


Most upvoted comments

I rolled diffusers, along with train_dreambooth_lora_sdxl.py, back to v0.21.4 while keeping all other dependencies at their latest versions, and the problem did not occur. So the regression is likely entirely within the diffusers repo, and probably introduced within the past couple of days: I used the script without problems a couple of days ago. FWIW. FYI @sayakpaul

What helped me as a workaround: I forced the dtype with .to(torch.float32) in diffusers/src/diffusers/models/attention_processor.py, around line 1134:

        # Workaround: force query/key/value to one dtype before the xFormers call
        query = attn.head_to_batch_dim(query).contiguous().to(torch.float32)
        key = attn.head_to_batch_dim(key).contiguous().to(torch.float32)
        value = attn.head_to_batch_dim(value).contiguous().to(torch.float32)
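The workaround casts all three tensors to float32, which avoids the error at the cost of running attention in full precision. A slightly more general sketch, using a hypothetical helper `cast_qkv`, makes the target dtype explicit: float32 mirrors the workaround, while float16 keeps the memory savings of mixed precision at the cost of rounding the query. This is an illustration, not the fix that was merged upstream.

```python
import torch


def cast_qkv(query, key, value, dtype=torch.float32):
    """Hypothetical helper: cast query/key/value to a single dtype before an
    attention kernel (such as xFormers) that requires matching dtypes."""
    return query.to(dtype), key.to(dtype), value.to(dtype)


query = torch.randn(2, 4, 8)                    # float32, like the failing query
key = torch.randn(2, 6, 8).to(torch.float16)    # float16, like key/value
value = torch.randn(2, 6, 8).to(torch.float16)

# Option 1: upcast everything, as in the workaround (safe, but more memory).
q32, k32, v32 = cast_qkv(query, key, value, torch.float32)
# Option 2: downcast the query to match key/value (keeps float16 savings).
q16, k16, v16 = cast_qkv(query, key, value, key.dtype)
print(q32.dtype, k32.dtype, v32.dtype)
print(q16.dtype, k16.dtype, v16.dtype)
```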

A quick fix would be to disable xFormers and rely on our default SDPA attention processor (used when a PyTorch 2.0 environment is detected).

I just tried it by removing the --enable_xformers_memory_efficient_attention flag and it worked as expected.
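For reference, PyTorch 2.0 provides `torch.nn.functional.scaled_dot_product_attention` (SDPA), which the SDPA processor builds on. The sketch below shows a simple availability check and a direct SDPA call on small CPU tensors; `pick_attention_backend` is a hypothetical helper, not the selection logic diffusers actually uses. SDPA also expects query, key, and value to share a dtype, so the fact that dropping the flag helps presumably means the inputs reach SDPA with matching dtypes, not that SDPA accepts mixed ones.

```python
import torch
import torch.nn.functional as F


def pick_attention_backend(enable_xformers: bool) -> str:
    """Hypothetical helper: prefer PyTorch 2.0 SDPA unless xFormers is requested."""
    if enable_xformers:
        return "xformers"
    if hasattr(F, "scaled_dot_product_attention"):  # present in PyTorch >= 2.0
        return "sdpa"
    return "vanilla"


# (batch, heads, seq_len, head_dim); SDPA attends over the last two dims.
query = torch.randn(1, 2, 4, 8)
key = torch.randn(1, 2, 6, 8)
value = torch.randn(1, 2, 6, 8)

backend = pick_attention_backend(enable_xformers=False)
if backend == "sdpa":
    out = F.scaled_dot_product_attention(query, key, value)
    print(backend, tuple(out.shape))
```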

@bluusun @odieXin @ckeisc could you give this a try?

@odieXin After some investigation, I noticed that the first norm1 itself returns norm_hidden_states in torch.float32.

https://github.com/huggingface/diffusers/blob/4d2c981d551566961bbc7254ae9556d76dd95764/src/diffusers/models/attention.py#L208

This float32 output is then propagated through the rest of the computation.
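A toy illustration of how a single float32 tensor can spread: elementwise ops (additions such as residual connections) follow PyTorch's type promotion, so mixing float16 and float32 yields float32 until something casts back. Matrix multiplications such as `nn.Linear` do not promote and instead raise on mismatched dtypes outside autocast, so this sketch only demonstrates the promotion rule, not the exact path inside the real transformer block. The tensors are stand-ins, not the actual activations.

```python
import torch

# Stand-ins: a norm output that came back as float32, and a float16 tensor elsewhere.
norm_hidden_states = torch.ones(2, 3, dtype=torch.float32)
fp16_tensor = torch.ones(2, 3, dtype=torch.float16)

# Mixed float16/float32 elementwise arithmetic promotes to float32.
mixed = norm_hidden_states + fp16_tensor
print(mixed.dtype)
print(torch.promote_types(torch.float16, torch.float32))

# Only an explicit cast brings the result back to float16.
print(mixed.to(torch.float16).dtype)
```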

@DN6 I can add a check in the computation there (the code block linked above) to ensure the dtypes always match. But do you have any other insights or suggestions? Cc: @patrickvonplaten too.
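One shape such a check could take, sketched under assumptions: cast the norm output back to the dtype of the block's incoming hidden_states before it feeds attention. `normalize_keep_dtype` is a hypothetical helper and `upcasting_norm` is a toy stand-in for a norm layer whose output comes back in float32; neither is diffusers code, and the maintainers may have chosen a different fix.

```python
import torch


def normalize_keep_dtype(norm, hidden_states: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: apply `norm`, then restore the input dtype so later
    query/key/value projections see consistent dtypes."""
    out = norm(hidden_states)
    if out.dtype != hidden_states.dtype:
        out = out.to(hidden_states.dtype)
    return out


def upcasting_norm(x: torch.Tensor) -> torch.Tensor:
    """Toy stand-in for a norm whose output comes back in float32."""
    x = x.float()
    return (x - x.mean(dim=-1, keepdim=True)) / (x.std(dim=-1, keepdim=True) + 1e-5)


hidden_states = torch.randn(2, 4, 8).to(torch.float16)
print(upcasting_norm(hidden_states).dtype)
print(normalize_keep_dtype(upcasting_norm, hidden_states).dtype)
```

Casting to the input's dtype, rather than always to float16, keeps the helper a no-op for full-precision runs.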