ColossalAI: [BUG]: Fail to load HuggingFace pretrained weights when using shardinit
🐛 Describe the bug
import torch
import torch.distributed
from colossalai.tensor import ProcessGroup, ShardSpec
from colossalai.utils import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext
from transformers import BloomForCausalLM

world_size = torch.distributed.get_world_size()
shard_pg = ProcessGroup(tp_degree=world_size) if args.shardinit else None
default_dist_spec = ShardSpec([-1], [world_size]) if args.shardinit else None

with ColoInitContext(device=get_current_device(),
                     dtype=torch.half,
                     default_dist_spec=default_dist_spec,
                     default_pg=shard_pg):
    model = BloomForCausalLM.from_pretrained(args.model_name_or_path)
When using shardinit, the model parameters are first sharded across multiple GPUs and only then is the HuggingFace pretrained checkpoint loaded. Each rank therefore holds only a slice of every weight along the last dimension (e.g., 4096 / 8 = 512 with a world size of 8), while the checkpoint stores the full-size tensors, so a shape mismatch occurs:
RuntimeError: Error(s) in loading state_dict for BloomForCausalLM:
    size mismatch for transformer.word_embeddings.weight: copying a param with shape torch.Size([46145, 4096]) from checkpoint, the shape in current model is torch.Size([46145, 512]).
    size mismatch for transformer.word_embeddings_layernorm.weight: copying a param with shape torch.Size([4096]) from checkpoint, the shape in current model is torch.Size([512]).
    size mismatch for transformer.word_embeddings_layernorm.bias: copying a param with shape torch.Size([4096]) from checkpoint, the shape in current model is torch.Size([512]).
    size mismatch for transformer.h.0.input_layernorm.weight: copying a param with shape torch.Size([4096]) from checkpoint, the shape in current model is torch.Size([512]).
I would like to know how to successfully load HuggingFace pretrained weights when using shardinit; it seems necessary when fine-tuning a very large model.
Environment
No response
About this issue
- Original URL
- State: open
- Created a year ago
- Comments: 17 (3 by maintainers)
Nope, didn’t test.
Yes, that’s correct.
A workaround is to construct the model first and then load the weights manually. And there is a better but cumbersome way:
- Replace ColoInitContext with LazyInitContext to construct the model faster (you might need to pass to_meta=False).
- Use the get_slice API from safetensors. This way, the weights should be able to be loaded directly into GPUs without first being loaded into the CPU.
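For illustration, here is a minimal sketch of that manual-loading approach using the safetensors get_slice API. It assumes a single-file safetensors checkpoint, sharding along the last dimension as in the ShardSpec([-1], [world_size]) setup above, and checkpoint keys that match model.named_parameters(); the helper name load_sharded_pretrained and the slicing logic are illustrative rather than part of the ColossalAI API, and whether param.copy_() works on a ColossalAI sharded tensor depends on the version.

import torch
from safetensors import safe_open

def load_sharded_pretrained(model, checkpoint_path, rank, world_size):
    # Hypothetical helper: copy each rank's slice of the full checkpoint
    # tensors into the already-sharded model parameters.
    with safe_open(checkpoint_path, framework="pt", device="cpu") as f:
        keys = set(f.keys())
        for name, param in model.named_parameters():
            if name not in keys:
                continue  # tied or renamed weights would need extra mapping
            full = f.get_slice(name)        # lazy handle, nothing read yet
            full_shape = full.get_shape()
            if param.shape[-1] == full_shape[-1]:
                data = full[:]              # parameter was not sharded
            else:
                # Sharded along the last dim by ShardSpec([-1], [world_size])
                shard = full_shape[-1] // world_size
                index = [slice(None)] * (len(full_shape) - 1)
                index.append(slice(rank * shard, (rank + 1) * shard))
                data = full[tuple(index)]   # reads only this rank's slice
            with torch.no_grad():
                param.copy_(data.to(device=param.device, dtype=param.dtype))

Because each rank reads only its own slice from disk, the full checkpoint never has to be materialized in CPU memory first.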