bitsandbytes: 8-bit optimizers don't work with FSDP

When I use an 8-bit Adam optimizer with FSDP, I get the following error:

RuntimeError: output tensor must have the same type as input tensor

If my understanding is correct, there seems to be a casting issue. Is there any workaround for this?

TIA.

About this issue

  • State: open
  • Created 2 years ago
  • Reactions: 2
  • Comments: 20 (1 by maintainers)

Most upvoted comments

I encountered a similar issue using PEFT LoRA, load_in_8bit, and DeepSpeed ZeRO-3 (optimizer and params offload) with the Hugging Face accelerator. On a single GPU, training worked fine as expected.

If anyone found a workaround to enable parallel training with PEFT LoRA and load_in_8bit, please let me know.

@Titus-von-Koeller On one hand, we can work around this by loading all the quantities onto the GPU, but this will be very inefficient. On the other hand, I feel the better approach would be to run the optimizer step alongside the FSDP sharding.

As we see here, the optimizer step can be run after the FSDP post-grad hook. There is a comment there saying that with CPU offload the parameters and gradients live on the CPU, but this need not be the case. If, during offload, we can run the optimizer step on the GPU before the state gets offloaded, then this solves our problem and we do not need to shuffle params around.
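Purely to illustrate that idea (this is not FSDP-aware and does not by itself fix the offload problem; the toy model is made up), here is a sketch of "step each parameter as soon as its gradient is ready" using PyTorch's post-accumulate-grad hooks, available in recent PyTorch releases:

```python
import torch
import bitsandbytes as bnb

model = torch.nn.Linear(1024, 1024).cuda()  # illustrative toy model

# One 8-bit optimizer per parameter, so each parameter can be stepped as
# soon as its gradient has been accumulated, while it is still on the GPU.
optim_per_param = {p: bnb.optim.Adam8bit([p]) for p in model.parameters()}

def step_in_backward(param: torch.Tensor) -> None:
    optim_per_param[param].step()
    optim_per_param[param].zero_grad()

for p in model.parameters():
    # Fires right after the gradient for `p` is produced in backward().
    p.register_post_accumulate_grad_hook(step_in_backward)

loss = model(torch.randn(8, 1024, device="cuda")).sum()
loss.backward()  # the optimizer steps happen inside the backward pass
```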

I have posted a comment on pytorch asking when FSDP will start to support running optim.step on the GPU. I will keep you updated when I get a response.

Noting that this issue, although stale, remains a problem. Although optimization can run, a functional state dict cannot be saved with 8-bit Adam.

I notice that there is a PR for FSDP functionality in https://github.com/TimDettmers/bitsandbytes/pull/840, but it does not address the state dict issue in its tests.

I do not test via huggingface.

I was in fact only trying to use an 8-bit optimizer with 32-bit weights, though, so I do not experience the int8 flat-parameter issue you do.

Hi @TimDettmers, in my latest test it turns out that saving the model is the source of this issue.

Specifically, the error pops up when I run this: `optim_state = FSDP.full_optim_state_dict(model, optimizer)`

What this is supposed to do is assemble the full optimizer state dict based on the model params. What I think the problem is: the optimizer is in 8-bit but the model is not in 8-bit. The reason for my assumption is that the error is thrown by

File "/share03/draj/environments/.conda/envs/yanmtt/lib/python3.9/site-packages/torch/distributed/distributed_c10d.py", line 2136, in _all_gather_base
    work = group._allgather_base(output_tensor, input_tensor)

Indeed if you look here: https://github.com/pytorch/pytorch/blob/55daa835e97a6e742cba1f0e9d2a5c78b1615e99/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp#L2779

Then you will see there is a constraint that the dtypes of the tensors must match, and we cannot guarantee this for a sharded 8-bit optimizer.

If we can find some way to bypass this requirement, then we are good to go.
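To make the failing path concrete, here is a minimal sketch of the kind of setup that hits this (the toy model and hyperparameters are illustrative, and torch.distributed is assumed to be initialized, e.g. via torchrun; only the final call is the one quoted above):

```python
import torch
import bitsandbytes as bnb
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Assumes torch.distributed is already initialized (e.g. launched via torchrun).
model = FSDP(torch.nn.Linear(4096, 4096).cuda())          # illustrative model
optimizer = bnb.optim.Adam8bit(model.parameters(), lr=1e-4)

loss = model(torch.randn(8, 4096, device="cuda")).sum()
loss.backward()
optimizer.step()  # the training step itself runs fine

# This is the call that fails: it all-gathers the sharded optimizer state
# against the fp32 model params, and the 8-bit state does not match.
optim_state = FSDP.full_optim_state_dict(model, optimizer)
```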

How do we overcome this issue?

@Titus-von-Koeller @TimDettmers I think the problem still remains even with BNB 0.43. The reason is that BNB performs its optimizer steps with CUDA.

  1. When using CPU offload, the gradients are put onto the CPU.
  2. However, before the BNB 8-bit optimizer step there is a pre_call to put all of the tensors onto the same GPU:

         prev_device = pre_call(g.device)

  3. Since the gradient g is on the CPU, it is clear why pre_call will fail, because device="cpu" below:

         def pre_call(device):
             prev_device = torch.cuda.current_device()
             torch.cuda.set_device(device)
             return prev_device

  4. And finally, all of the optimizer quantities in the is_on_gpu call are on the CPU:

         is_on_gpu([g, p, state1, state2, qmap1, qmap2, absmax1, absmax2])

Thus, while one can move all of the above quantities to GPU -> compute -> back to CPU, I'm not sure this is the optimal way to do things, as it will involve a lot of IO overhead.
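For illustration only, a rough sketch of that "gpu -> compute -> cpu" shuffle might look like the following (the helper names are made up, and this ignores FSDP's flat-parameter bookkeeping entirely; it only shows where the extra IO overhead comes from):

```python
import torch

def _move_optimizer(optimizer: torch.optim.Optimizer, device: str) -> None:
    # Move params, grads, and every tensor in the optimizer state to `device`.
    for group in optimizer.param_groups:
        for p in group["params"]:
            p.data = p.data.to(device)
            if p.grad is not None:
                p.grad = p.grad.to(device)
            for k, v in optimizer.state.get(p, {}).items():
                if torch.is_tensor(v):
                    optimizer.state[p][k] = v.to(device)

def step_on_gpu(optimizer: torch.optim.Optimizer) -> None:
    # Shuttle everything to the GPU so bnb's CUDA kernels (and its is_on_gpu
    # check) are satisfied, step there, then shuttle it back for CPU offload.
    _move_optimizer(optimizer, "cuda")
    optimizer.step()
    _move_optimizer(optimizer, "cpu")
```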

@Titus-von-Koeller @TimDettmers sorry to hijack this issue. Doing something related but not exactly the same.

I'm trying to use FSDP with bitsandbytes==0.42.0 to finetune EleutherAI/pythia-1b with 8-bit weights:

  • AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-1b", load_in_8bit=True)
  • added LoRA adapters, and I have different FSDP wrappers for anything that is not bnb.Linear8bitLt:
```
GPTNeoXLayer(
  (input_layernorm): FullyShardedDataParallel(
    (_fsdp_wrapped_module): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
  )
  (post_attention_layernorm): FullyShardedDataParallel(
    (_fsdp_wrapped_module): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
  )
  (post_attention_dropout): Dropout(p=0.0, inplace=False)
  (post_mlp_dropout): Dropout(p=0.0, inplace=False)
  (attention): GPTNeoXAttention(
    (rotary_emb): FullyShardedDataParallel(
      (_fsdp_wrapped_module): GPTNeoXRotaryEmbedding()
    )
    (query_key_value): lora.Linear8bitLt(
      (base_layer): Linear8bitLt(in_features=2048, out_features=6144, bias=True)
      (lora_dropout): ModuleDict(
        (default): Dropout(p=0.1, inplace=False)
      )
      (lora_A): ModuleDict(
        (default): FullyShardedDataParallel(
          (_fsdp_wrapped_module): Linear(in_features=2048, out_features=8, bias=False)
        )
      )
      (lora_B): ModuleDict(
        (default): FullyShardedDataParallel(
          (_fsdp_wrapped_module): Linear(in_features=8, out_features=6144, bias=False)
        )
      )
      (lora_embedding_A): ParameterDict()
      (lora_embedding_B): ParameterDict()
    )
    (dense): Linear8bitLt(in_features=2048, out_features=2048, bias=True)
    (attention_dropout): Dropout(p=0.0, inplace=False)
  )
  (mlp): GPTNeoXMLP(
    (dense_h_to_4h): Linear8bitLt(in_features=2048, out_features=8192, bias=True)
    (dense_4h_to_h): Linear8bitLt(in_features=8192, out_features=2048, bias=True)
    (act): GELUActivation()
  )
)
```

The FSDP wrapping will fail at _validate_tensors_to_flatten when it tries to flatten the Linear8bitLt modules for sharding. This is because the Linear8bitLt weights have dtype torch.int8, and _validate_tensors_to_flatten requires a floating-point dtype.
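For anyone trying the same thing, a rough sketch of how one might express "wrap everything except the int8 Linear8bitLt modules" with a lambda auto-wrap policy (names are illustrative, and this does not fix the underlying limitation: the unwrapped int8 weights still fall into the root unit's flat parameter, which is presumably why _validate_tensors_to_flatten is still hit):

```python
import functools

import torch
import bitsandbytes as bnb
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy

def wrap_everything_but_int8(module: torch.nn.Module) -> bool:
    # Only wrap modules whose own (non-recursive) parameters are all floating
    # point, so no int8 Linear8bitLt weight is flattened by a nested FSDP unit.
    if isinstance(module, bnb.nn.Linear8bitLt):
        return False
    params = list(module.parameters(recurse=False))
    return len(params) > 0 and all(p.dtype.is_floating_point for p in params)

# `model` stands for the PEFT/LoRA-wrapped 8-bit model loaded above.
fsdp_model = FSDP(
    model,
    auto_wrap_policy=functools.partial(
        lambda_auto_wrap_policy, lambda_fn=wrap_everything_but_int8
    ),
)
```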

anyone still working on this…?

On the error @prajdabre was mentioning, I find that the problem does not come from a dtype mismatch but rather a size mismatch. With printf debugging, I noticed that this seemed to first error on the absmax1 value, with:

    output_tensor.shape == Size([361496576]), output_tensor.dtype == float32
    input_tensor.shape  == Size([22064]),     input_tensor.dtype  == float32
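For context, the collective underneath full_optim_state_dict requires matching dtypes and, per rank, an output buffer that is exactly world_size times the input shard. Presumably the buffer here is sized from the fp32 flat parameter while the shard handed in is the much smaller absmax1, hence the mismatch. A tiny illustration of that contract (the helper is made up; the numbers are the ones from the debug output above):

```python
# Contract enforced by dist.all_gather_into_tensor (the newer name for the
# _all_gather_base call in the traceback): dtypes must match and
#   output_tensor.numel() == world_size * input_tensor.numel()

def gather_shapes_compatible(output_numel: int, input_numel: int, world_size: int) -> bool:
    return output_numel == world_size * input_numel

# Gather buffer sized from the fp32 flat parameter vs. the 8-bit optimizer's
# absmax1 shard, as printed above:
print(gather_shapes_compatible(361496576, 22064, world_size=2))  # False
```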