bitsandbytes: 8-bit optimizers don't work with FSDP
When I use an 8-bit Adam with FSDP, I get the following error:
RuntimeError: output tensor must have the same type as input tensor
If my understanding is correct, there seems to be a casting issue. Is there a workaround for this?
TIA.
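For reference, a minimal sketch of the kind of setup that hits this (the model, sizes, and launch details below are illustrative placeholders, not taken from the report above):

```python
# Illustrative repro sketch only: wrap a toy model in FSDP and drive it with
# bitsandbytes' 8-bit Adam. Assumes torchrun has launched one process per GPU
# on a single node.
import torch
import torch.distributed as dist
import torch.nn as nn
import bitsandbytes as bnb
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank())

model = nn.Sequential(nn.Linear(1024, 4096), nn.ReLU(), nn.Linear(4096, 1024)).cuda()
model = FSDP(model)
optimizer = bnb.optim.Adam8bit(model.parameters(), lr=1e-4)

x = torch.randn(8, 1024, device="cuda")
loss = model(x).square().mean()
loss.backward()
optimizer.step()
optimizer.zero_grad()

# Depending on the setup, the step above or the state-dict gather below is
# where the "output tensor must have the same type as input tensor" error
# gets reported in this thread.
# optim_state = FSDP.full_optim_state_dict(model, optimizer)
```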
I encountered a similar issue using PEFT LoRA, load_in_8bit, and DeepSpeed ZeRO stage 3 (optimizer and parameter offload) with the Hugging Face accelerator. On a single GPU, training worked fine as expected.
If anyone has found a workaround to enable parallel training with PEFT LoRA and load_in_8bit, please let me know.
@Titus-von-Koeller On one hand, we can work around this by loading all of the quantities onto the GPU, but this will be very inefficient. On the other hand, I feel the better approach would be to run the optimizer step alongside the FSDP sharding.
As we see here, the optimizer step can be run after the FSDP post-grad hook. There is a comment there saying that with CPU offload the parameters and gradients live on the CPU, but this should not have to be the case. If, during offload, we can run the optimizer step on the GPU before the parameters get offloaded, then this solves our problem and we do not need to shuffle params around.
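As an editorial aside, the general "optimizer step inside backward" pattern being described here can be sketched with the per-parameter hook added in PyTorch 2.1; how to wire this into FSDP's own post-grad and offload machinery is exactly the open question, so treat this as an illustration, not a fix:

```python
# Sketch of running the 8-bit optimizer step as each gradient becomes ready,
# while the tensors are still on the GPU. Assumes PyTorch >= 2.1 and does not
# address the FSDP/offload interaction discussed above.
import torch
import bitsandbytes as bnb

model = torch.nn.Linear(1024, 1024, device="cuda")

# One 8-bit optimizer per parameter, stepped as soon as that parameter's
# gradient has been accumulated.
optimizers = {p: bnb.optim.Adam8bit([p], lr=1e-4) for p in model.parameters()}

def step_now(param: torch.Tensor) -> None:
    optimizers[param].step()
    param.grad = None  # free the gradient immediately instead of offloading it

for p in model.parameters():
    p.register_post_accumulate_grad_hook(step_now)

x = torch.randn(8, 1024, device="cuda")
model(x).sum().backward()  # optimizer steps happen inside backward
```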
I have posted a comment on pytorch asking when FSDP will start to support running optim.step on the GPU. I will keep you updated when I get a response.

Noting that this issue, although stale, remains an issue. Although the optimization itself can run, a functional state dict cannot be saved with 8-bit Adam.
I notice that there is a PR for FSDP functionality in https://github.com/TimDettmers/bitsandbytes/pull/840. However, it does not address the state dict issue in its tests.
I do not test via Hugging Face.
I was in fact only trying to use an 8-bit optimiser with 32-bit weights, though, so I do not experience the int8 FlatParameter issue you do.
Hi @TimDettmers, in my latest test it turns out that saving the model is the source of this issue.
Specifically the error pops up when I run this: optim_state = FSDP.full_optim_state_dict(model, optimizer)
What this is supposed to do is assemble the entire optimizer state based on the model params. What I think the problem is: the optimizer state is in 8-bit, but the model is not. The reason for my assumption is that the error is thrown by
File "/share03/draj/environments/.conda/envs/yanmtt/lib/python3.9/site-packages/torch/distributed/distributed_c10d.py", line 2136, in _all_gather_base
    work = group._allgather_base(output_tensor, input_tensor)
Indeed, if you look here: https://github.com/pytorch/pytorch/blob/55daa835e97a6e742cba1f0e9d2a5c78b1615e99/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp#L2779
there is a constraint that the dtypes of the input and output tensors must be the same, and we are not able to guarantee this for a sharded 8-bit optimizer.
If we can find some way to bypass this requirement, then we are good to go.
How do we overcome this issue?
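To make the constraint concrete, here is a small editorial illustration (not from the thread): gathering uint8 8-bit optimizer state into a float32 output buffer trips exactly the NCCL check linked above.

```python
# Minimal illustration of the dtype check in ProcessGroupNCCL's allgather.
# Assumes torchrun has launched >= 2 processes and NCCL is available.
import torch
import torch.distributed as dist

dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank())

# 8-bit optimizer state shards are uint8 ...
local_state = torch.randint(0, 256, (1024,), dtype=torch.uint8, device="cuda")
# ... but a gather buffer sized from the fp32 model params is float32.
output = torch.empty(1024 * dist.get_world_size(), dtype=torch.float32, device="cuda")

# RuntimeError: output tensor must have the same type as input tensor
dist.all_gather_into_tensor(output, local_state)
```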
@Titus-von-Koeller @TimDettmers I think the problem still remains even with BNB 0.43. The reason is that BNB performs the optimizer step with CUDA kernels.
The update uses pre_call to put all of the tensors onto the same GPU. If g is on cpu, it is obvious why pre_call will fail, since device="cpu" at that point; likewise, the tensors checked by is_on_gpu are on the cpu. Thus, while one can move all of the above quantities gpu -> compute -> cpu, I'm not sure that is the most optimal way to do things, as it will involve a lot of IO overhead.
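A rough sketch of that gpu -> compute -> cpu round trip, purely as an editorial illustration (the wrapper name is made up, and how it interacts with FSDP's own offloading is an open assumption):

```python
# Hypothetical helper: temporarily move CPU-offloaded params/grads onto the GPU
# so bitsandbytes' CUDA optimizer kernels can run, then move them back.
# This is the IO-heavy workaround the comment above cautions about.
import torch

class GpuRoundTripStep:
    def __init__(self, optimizer: torch.optim.Optimizer, device: str = "cuda"):
        self.optimizer = optimizer
        self.device = device

    @torch.no_grad()
    def step(self) -> None:
        moved = []
        for group in self.optimizer.param_groups:
            for p in group["params"]:
                if p.grad is not None and p.device.type == "cpu":
                    moved.append(p)
                    p.data = p.data.to(self.device)
                    p.grad.data = p.grad.data.to(self.device)
        self.optimizer.step()          # 8-bit state stays on the GPU
        for p in moved:                # pay the transfer cost again on the way back
            p.data = p.data.cpu()
            p.grad.data = p.grad.data.cpu()
```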
@Titus-von-Koeller @TimDettmers sorry to hijack this issue; I'm doing something related but not exactly the same.
I'm trying to use FSDP with bitsandbytes==0.42.0 to finetune EleutherAI/pythia-1b with 8-bit weights, i.e. AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-1b", load_in_8bit=True), which gives bnb.Linear8bitLt layers. The FSDP wrapping fails at _validate_tensors_to_flatten when it tries to flatten Linear8bitLt for sharding. This is because Linear8bitLt.dtype is torch.int8, and _validate_tensors_to_flatten requires that it be a floating point type.

anyone still working on this…?
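On the Linear8bitLt flattening failure above, one possible way to sidestep it, as a sketch only (keeping the int8 modules out of sharding via ignored_modules means they are replicated on every rank, and whether this actually trains correctly has not been verified in this thread):

```python
# Sketch: exclude int8 Linear8bitLt modules from FSDP's flat-parameter sharding
# so _validate_tensors_to_flatten only sees floating-point params. The int8
# weights are then simply kept replicated on each rank.
import torch
import bitsandbytes as bnb
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "EleutherAI/pythia-1b",
    load_in_8bit=True,
    device_map={"": torch.cuda.current_device()},  # load the 8-bit weights on this rank's GPU
)
int8_modules = [m for m in model.modules() if isinstance(m, bnb.nn.Linear8bitLt)]
fsdp_model = FSDP(model, ignored_modules=int8_modules)
```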
On the error @prajdabre was mentioning, I find that the problem does not come from a dtype mismatch, but rather a size mismatch. With printf debugging, I noticed that this seemed to first error on the absmax1 value, with