bitsandbytes: RuntimeError: expected scalar type Half but found Float

I hit the following error when trying to train bloom-7b1-mt with PEFT LoRA in 8-bit + fp16 (torch AMP) mode:

Traceback (most recent call last):
  File "finetune.py", line 141, in <module>
    train(args)
  File "finetune.py", line 133, in train
    trainer.train()
  File "/home/yyq/.conda/envs/transformers/lib/python3.7/site-packages/transformers/trainer.py", line 1638, in train
    ignore_keys_for_eval=ignore_keys_for_eval,
  File "/home/yyq/.conda/envs/transformers/lib/python3.7/site-packages/transformers/trainer.py", line 1903, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
  File "/home/yyq/.conda/envs/transformers/lib/python3.7/site-packages/transformers/trainer.py", line 2660, in training_step
    tmp.backward()
  File "/home/yyq/.conda/envs/transformers/lib/python3.7/site-packages/torch/_tensor.py", line 489, in backward
    self, gradient, retain_graph, create_graph, inputs=inputs
  File "/home/yyq/.conda/envs/transformers/lib/python3.7/site-packages/torch/autograd/__init__.py", line 199, in backward
    allow_unreachable=True, accumulate_grad=True)  # Calls into the C++ engine to run the backward pass
  File "/home/yyq/.conda/envs/transformers/lib/python3.7/site-packages/torch/autograd/function.py", line 267, in apply
    return user_fn(self, *args)
  File "/home/yyq/.conda/envs/transformers/lib/python3.7/site-packages/torch/utils/checkpoint.py", line 157, in backward
    torch.autograd.backward(outputs_with_grad, args_with_grad)
  File "/home/yyq/.conda/envs/transformers/lib/python3.7/site-packages/torch/autograd/__init__.py", line 199, in backward
    allow_unreachable=True, accumulate_grad=True)  # Calls into the C++ engine to run the backward pass
  File "/home/yyq/.conda/envs/transformers/lib/python3.7/site-packages/torch/autograd/function.py", line 267, in apply
    return user_fn(self, *args)
  File "/home/yyq/.conda/envs/transformers/lib/python3.7/site-packages/bitsandbytes/autograd/_functions.py", line 456, in backward
    grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A)
RuntimeError: expected scalar type Half but found Float

The error does not occur when training llama-7b with exactly the same settings.

It also disappears if I set fp16=False:

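import torch
import transformers
from peft import get_peft_model, prepare_model_for_int8_training
from transformers import AutoModel

# BLOOM_MODEL_PATH and lora_config are defined elsewhere in the script
model = AutoModel.from_pretrained(
    BLOOM_MODEL_PATH,
    trust_remote_code=True,
    load_in_8bit=True,
    torch_dtype=torch.float16,
    device_map='auto',
)
model = prepare_model_for_int8_training(model)
model = get_peft_model(model, lora_config)
...
trainer = transformers.Trainer(
    args=transformers.TrainingArguments(
        ...
        fp16=False,  # with fp16=True, the error above is raised
        ...
    ),
    ...
)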
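One setup (llama-7b) works and a near-identical one (bloom-7b1-mt) fails. Comparing the parameter dtypes each model ends up with after prepare_model_for_int8_training and get_peft_model may help narrow this down. Below is a minimal sketch with two hypothetical helpers, param_dtype_summary and trainable_dtypes. It is shown on a toy model rather than BLOOM.

```python
from collections import Counter

import torch
from torch import nn


def param_dtype_summary(model: nn.Module) -> Counter:
    """Count parameter elements per dtype (hypothetical diagnostic helper)."""
    counts = Counter()
    for param in model.parameters():
        counts[param.dtype] += param.numel()
    return counts


def trainable_dtypes(model: nn.Module) -> set:
    """Return the set of dtypes among parameters that will receive gradients."""
    return {p.dtype for p in model.parameters() if p.requires_grad}


# Toy model: one half-precision layer and one float32 layer.
toy = nn.Sequential(nn.Linear(4, 4).half(), nn.Linear(4, 2))
print(param_dtype_summary(toy))
print(trainable_dtypes(toy))
```

On the real model, you would call the same helpers on `model` right before building the Trainer.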

About this issue

  • State: open
  • Created a year ago
  • Reactions: 4
  • Comments: 26

Most upvoted comments

Hi!

I solved it by wrapping the training call in autocast:

with torch.autocast("cuda"): 
    trainer.train()

This error occurs when the two matrices you are multiplying do not have the same dtype.

Half means torch.float16, while Float means torch.float32.
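Here is a minimal sketch of that mismatch, assuming a recent PyTorch on CPU. The tensors and shapes are made up. The exact wording of the RuntimeError varies between PyTorch versions; older ones print "expected scalar type Half but found Float".

```python
import torch

# The error message uses PyTorch's legacy type names:
# "Half" is torch.float16 and "Float" is torch.float32.
assert torch.half == torch.float16
assert torch.float == torch.float32


def mixed_matmul_fails() -> bool:
    """Return True if multiplying a Half matrix by a Float matrix raises RuntimeError."""
    a = torch.randn(2, 3).to(torch.float16)  # Half
    b = torch.randn(3, 4)                    # Float (the default dtype)
    try:
        torch.matmul(a, b)  # matmul does not promote mismatched operand dtypes
    except RuntimeError as err:
        print("RuntimeError:", err)
        return True
    return False


def matched_matmul_dtype() -> torch.dtype:
    """Cast one operand so both share a dtype; the matmul then succeeds."""
    a = torch.randn(2, 3).to(torch.float16)
    b = torch.randn(3, 4)
    # float32 is used so this also runs on CPU-only builds;
    # on a GPU, casting b to float16 instead works as well.
    return torch.matmul(a.to(torch.float32), b).dtype


print(mixed_matmul_fails())
print(matched_matmul_dtype())
```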

To resolve the error, simply cast your model weights to float16:

for param in model.parameters():
    # Check if parameter dtype is Float (float32)
    if param.dtype == torch.float32:
        param.data = param.data.to(torch.float16)
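The loop above casts only parameters that are currently float32, so integer tensors such as int8 weights are left untouched. Below is a minimal sketch that wraps the loop in a hypothetical helper, cast_float32_params_to_half, and checks it on a toy model. The int8 parameter is a made-up stand-in, not how bitsandbytes actually stores its quantized weights.

```python
import torch
from torch import nn


def cast_float32_params_to_half(model: nn.Module) -> int:
    """Cast every float32 parameter to float16 in place; return how many were cast."""
    n_cast = 0
    for param in model.parameters():
        if param.dtype == torch.float32:  # skip int8 and already-half parameters
            param.data = param.data.to(torch.float16)
            n_cast += 1
    return n_cast


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)  # float32 weight and bias
        # Hypothetical stand-in for a frozen int8 weight.
        self.frozen_int8 = nn.Parameter(
            torch.zeros(4, 4, dtype=torch.int8), requires_grad=False
        )


toy = ToyModel()
print(cast_float32_params_to_half(toy))
print(toy.proj.weight.dtype, toy.frozen_int8.dtype)
```

`model.half()` is a blunter alternative: it casts every floating-point parameter and buffer to float16, not just the float32 ones.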

solved

I am using TRL’s SFTTrainer to train OPT (loaded in int8 + LoRA) and am seeing the same issue. While

with torch.autocast("cuda"): 
    trainer.train()

solves the precision mismatch issue, it causes the loss to become 0.

I was using PEFT with a Whisper model and hit this issue.

with torch.autocast("cuda"): 
    trainer.train()

solved the issue.

My guess: AMP runs operations that are less sensitive to precision in lower precision (float16, also known as half precision), while keeping others, such as the loss calculation, in higher precision (float32). The torch.autocast context manager enables AMP for the operations inside its block.
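Below is a minimal sketch of that behavior, assuming a recent PyTorch. Inside torch.autocast, a matmul between a Half and a Float tensor runs in the autocast dtype: float16 on CUDA, and bfloat16 by default for CPU autocast. The mixed pair that raises outside the block therefore goes through. This only illustrates autocast itself. It does not show why the bitsandbytes backward receives mismatched dtypes in this case.

```python
import torch


def autocast_matmul_dtype() -> torch.dtype:
    # Use CUDA when available; otherwise fall back to CPU autocast.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    a = torch.randn(2, 3, device=device).to(torch.float16)  # Half
    b = torch.randn(3, 4, device=device)                    # Float
    with torch.autocast(device):
        # autocast casts both matmul inputs to its lower-precision dtype,
        # so the Half/Float mismatch no longer raises here.
        out = torch.matmul(a, b)
    return out.dtype


print(autocast_matmul_dtype())
```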

for param in model.parameters():
    # Check if parameter dtype is Float (float32)
    if param.dtype == torch.float32:
        param.data = param.data.to(torch.float16)

Changing the parameters’ dtype manually worked for me. Training has started 😃