diffusers: Using --mixed_precision="fp16" brings ValueError: Query/Key/Value should all have the same dtype
Describe the bug
ValueError: Query/Key/Value should all have the same dtype
query.dtype: torch.float32
key.dtype: torch.float16
value.dtype: torch.float16
Reproduction
Use --mixed_precision="fp16" in the DreamBooth script.
Without --mixed_precision="fp16" it runs fine.
Logs
10/11/2023 14:43:29 - INFO - __main__ - Num examples = 5
10/11/2023 14:43:29 - INFO - __main__ - Num batches each epoch = 5
10/11/2023 14:43:29 - INFO - __main__ - Num Epochs = 134
10/11/2023 14:43:29 - INFO - __main__ - Instantaneous batch size per device = 1
10/11/2023 14:43:29 - INFO - __main__ - Total train batch size (w. parallel, distributed & accumulation) = 2
10/11/2023 14:43:29 - INFO - __main__ - Gradient Accumulation steps = 2
10/11/2023 14:43:29 - INFO - __main__ - Total optimization steps = 400
Steps: 0%| | 0/400 [00:00<?, ?it/s]Traceback (most recent call last):
File "/root/diffusers/examples/dreambooth/train_dreambooth_lora_sdxl.py", line 1366, in <module>
main(args)
File "/root/diffusers/examples/dreambooth/train_dreambooth_lora_sdxl.py", line 1106, in main
model_pred = unet(
File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
return forward_call(*input, **kwargs)
File "/usr/local/lib/python3.10/site-packages/accelerate/utils/operations.py", line 636, in forward
return model_forward(*args, **kwargs)
File "/usr/local/lib/python3.10/site-packages/accelerate/utils/operations.py", line 624, in __call__
return convert_to_fp32(self.model_forward(*args, **kwargs))
File "/usr/local/lib/python3.10/site-packages/torch/amp/autocast_mode.py", line 14, in decorate_autocast
return func(*args, **kwargs)
File "/root/diffusers/src/diffusers/models/unet_2d_condition.py", line 1010, in forward
sample, res_samples = downsample_block(
File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
return forward_call(*input, **kwargs)
File "/root/diffusers/src/diffusers/models/unet_2d_blocks.py", line 1108, in forward
hidden_states = attn(
File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
return forward_call(*input, **kwargs)
File "/root/diffusers/src/diffusers/models/transformer_2d.py", line 311, in forward
hidden_states = torch.utils.checkpoint.checkpoint(
File "/usr/local/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 251, in checkpoint
return _checkpoint_without_reentrant(
File "/usr/local/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 420, in _checkpoint_without_reentrant
output = function(*args, **kwargs)
File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
return forward_call(*input, **kwargs)
File "/root/diffusers/src/diffusers/models/attention.py", line 238, in forward
attn_output = self.attn2(
File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
return forward_call(*input, **kwargs)
File "/root/diffusers/src/diffusers/models/attention_processor.py", line 423, in forward
return self.processor(
File "/root/diffusers/src/diffusers/models/attention_processor.py", line 946, in __call__
hidden_states = xformers.ops.memory_efficient_attention(
File "/usr/local/lib/python3.10/site-packages/xformers/ops/fmha/__init__.py", line 197, in memory_efficient_attention
return _memory_efficient_attention(
File "/usr/local/lib/python3.10/site-packages/xformers/ops/fmha/__init__.py", line 298, in _memory_efficient_attention
return _fMHA.apply(
File "/usr/local/lib/python3.10/site-packages/xformers/ops/fmha/__init__.py", line 43, in forward
out, op_ctx = _memory_efficient_attention_forward_requires_grad(
File "/usr/local/lib/python3.10/site-packages/xformers/ops/fmha/__init__.py", line 320, in _memory_efficient_attention_forward_requires_grad
inp.validate_inputs()
File "/usr/local/lib/python3.10/site-packages/xformers/ops/fmha/common.py", line 116, in validate_inputs
raise ValueError(
ValueError: Query/Key/Value should all have the same dtype
query.dtype: torch.float32
key.dtype : torch.float16
value.dtype: torch.float16
System Info
Latest diffusers from the GitHub repo
Who can help?
No response
I rolled diffusers (along with train_dreambooth_lora_sdxl.py) back to v0.21.4 while keeping all other dependencies at their latest versions, and this problem did not happen, so the break should be entirely within the diffusers repo, probably within the past couple of days; I had used the script just a couple of days ago without a problem. FWIW. FYI @sayakpaul
What helped me as a workaround: I forced the type with .to(torch.float32) in diffusers/src/diffusers/models/attention_processor.py, around line 1134.
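For reference, a minimal sketch of that workaround outside of diffusers (tensor shapes are made up; the real cast goes inside the xformers attention processor's __call__, right before the xformers call):

```python
import torch
import xformers.ops

# Toy tensors reproducing the mismatch from the traceback: query in fp32,
# key/value in fp16 (shapes are illustrative: [batch, seq_len, num_heads, head_dim]).
query = torch.randn(1, 64, 8, 40, device="cuda", dtype=torch.float32)
key = torch.randn(1, 77, 8, 40, device="cuda", dtype=torch.float16)
value = torch.randn(1, 77, 8, 40, device="cuda", dtype=torch.float16)

# Calling xformers.ops.memory_efficient_attention(query, key, value) here would
# raise "ValueError: Query/Key/Value should all have the same dtype".

# The workaround: force a single dtype before the call (the comment above used
# torch.float32; casting query down to key.dtype would keep the fp16 memory savings).
query, key, value = (t.to(torch.float32) for t in (query, key, value))
out = xformers.ops.memory_efficient_attention(query, key, value)
print(out.shape, out.dtype)  # torch.Size([1, 64, 8, 40]) torch.float32
```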
A quick fix would be to disable xformers and rely on our default SDPA attention processor (when a PT 2.0 environment is detected).
I just tried it by removing the --enable_xformers_memory_efficient_attention flag and it worked as expected. @bluusun @odieXin @ckeisc could you give this a try?
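For anyone who wants to select the non-xformers path explicitly rather than just dropping the flag, here is a minimal standalone sketch (not part of the training script; the SDXL checkpoint name is only an example):

```python
import torch
from diffusers import UNet2DConditionModel
from diffusers.models.attention_processor import AttnProcessor2_0

# Load the SDXL UNet (example checkpoint). On PyTorch 2.0+ diffusers already
# defaults to the SDPA-based AttnProcessor2_0, so in the training script simply
# omitting --enable_xformers_memory_efficient_attention is enough.
unet = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", torch_dtype=torch.float16
)

# Explicitly (re)select the torch.nn.functional.scaled_dot_product_attention path
# instead of calling unet.enable_xformers_memory_efficient_attention().
unet.set_attn_processor(AttnProcessor2_0())
```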
@odieXin after some investigation, I noticed that the first norm1 itself returns norm_hidden_states in torch.float32: https://github.com/huggingface/diffusers/blob/4d2c981d551566961bbc7254ae9556d76dd95764/src/diffusers/models/attention.py#L208
This is then propagated throughout the rest of the computation.
@DN6 I can add a check in the computations there (the code block mentioned above) to ensure the data types always match. But do you have any other insights / suggestions? Cc: @patrickvonplaten too.
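To make that proposal concrete, here is a hedged sketch of the kind of dtype check being discussed. The align_dtypes helper and its call site are hypothetical, purely for illustration; this is not the diffusers API or the actual patch:

```python
import torch

def align_dtypes(norm_hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor):
    # Hypothetical helper (not diffusers API): cast the LayerNorm output back to the
    # dtype of the cross-attention context so the downstream Q/K/V projections agree.
    if norm_hidden_states.dtype != encoder_hidden_states.dtype:
        norm_hidden_states = norm_hidden_states.to(encoder_hidden_states.dtype)
    return norm_hidden_states, encoder_hidden_states

# The mismatch from the traceback: norm1 output in fp32 meeting fp16 text-encoder
# states under --mixed_precision="fp16" (shapes are illustrative, not exact SDXL sizes).
norm_hidden_states = torch.randn(1, 4096, 640, dtype=torch.float32)
encoder_hidden_states = torch.randn(1, 77, 2048, dtype=torch.float16)

norm_hidden_states, encoder_hidden_states = align_dtypes(norm_hidden_states, encoder_hidden_states)
assert norm_hidden_states.dtype == encoder_hidden_states.dtype == torch.float16
```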