ColossalAI: [BUG]: Fails to load Hugging Face pretrained weights when using shardinit

🐛 Describe the bug

        world_size = torch.distributed.get_world_size()
        shard_pg = ProcessGroup(tp_degree=world_size) if args.shardinit else None
        default_dist_spec = ShardSpec([-1], [world_size]) if args.shardinit else None

        # With --shardinit, every parameter is sharded along its last dim across all
        # ranks, so from_pretrained() builds the already-sharded model and then tries
        # to load the full (unsharded) checkpoint into it.
        with ColoInitContext(device=get_current_device(),
                             dtype=torch.half,
                             default_dist_spec=default_dist_spec,
                             default_pg=shard_pg):
            model = BloomForCausalLM.from_pretrained(args.model_name_or_path)

When shardinit is used, the model is first sharded across multiple GPUs and only then are the Hugging Face pretrained weights loaded, so a checkpoint shape mismatch occurs: here the hidden dimension 4096 is sharded down to 512 per GPU (i.e. world_size = 8), which no longer matches the full checkpoint tensors.

RuntimeError: Error(s) in loading state_dict for BloomForCausalLM:
        size mismatch for transformer.word_embeddings.weight: copying a param with shape torch.Size([46145, 4096]) from checkpoint, the shape in current model is torch.Size([46145, 512]).
        size mismatch for transformer.word_embeddings_layernorm.weight: copying a param with shape torch.Size([4096]) from checkpoint, the shape in current model is torch.Size([512]).
        size mismatch for transformer.word_embeddings_layernorm.bias: copying a param with shape torch.Size([4096]) from checkpoint, the shape in current model is torch.Size([512]).
        size mismatch for transformer.h.0.input_layernorm.weight: copying a param with shape torch.Size([4096]) from checkpoint, the shape in current model is torch.Size([512]).

I would like to know how to successfully load Hugging Face pretrained weights when using shardinit; it seems necessary when fine-tuning a very large model.

Environment

No response

About this issue

  • Original URL
  • State: open
  • Created a year ago
  • Comments: 17 (3 by maintainers)

Most upvoted comments

@ShinoharaHare Thanks for the reply. I'll give it a try. By the way, did you compare the acceleration performance of this strategy with Megatron?

Nope, didn’t test.

Hi @ShinoharaHare, I came across the same error. Thanks for your solution, I will give it a try. But I still have a question about "x = state_dict[n]": does it mean deserializing the Hugging Face checkpoint into a state_dict, e.g. state_dict = torch.load("xxx") (to CPU, maybe), before the ColoInitContext step?

Yes, that’s correct.
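For completeness, a minimal sketch of that preliminary step (the single-file checkpoint name below is an assumption; a sharded Hugging Face checkpoint would need each piece loaded and merged instead):

import torch

# Load the full Hugging Face checkpoint onto the CPU before entering ColoInitContext.
# "pytorch_model.bin" is the conventional single-file name; adjust the path as needed.
state_dict = torch.load("pytorch_model.bin", map_location="cpu")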

And there is a better but cumbersome way:

  1. Convert the pretrained weights into safetensors
  2. Use LazyInitContext with ColoInitContext to construct the model faster (might need to pass to_meta=False)
  3. To load only the required parts on each rank, you can utilize the get_slice API from safetensors. This way, the weights can be loaded directly onto the GPUs without first being staged on the CPU (see the sketch after this list).
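For illustration, here is a minimal sketch of step 3. It assumes the weights were converted to a single model.safetensors file, that the model was constructed under ColoInitContext with ShardSpec([-1], [world_size]) as in the snippets in this thread, that parameter names match the checkpoint keys, that every last dimension divides evenly by world_size, and that each rank maps to one local GPU; none of these details come from the thread itself.

import torch.distributed as dist
from safetensors import safe_open

rank = dist.get_rank()
world_size = dist.get_world_size()

# Every parameter is sharded along its last dimension (ShardSpec([-1], [world_size])),
# so each rank only reads its own slice of each tensor straight onto its GPU.
with safe_open("model.safetensors", framework="pt", device=f"cuda:{rank}") as f:
    for n, p in model.named_parameters():
        tensor_slice = f.get_slice(n)
        shape = tensor_slice.get_shape()
        step = shape[-1] // world_size            # assumes an even split across ranks
        start, end = rank * step, (rank + 1) * step
        if len(shape) == 1:                       # layernorm weights, biases, etc.
            shard = tensor_slice[start:end]
        else:                                     # 2-D weights (embeddings, linear layers)
            shard = tensor_slice[:, start:end]
        p.data.copy_(shard)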

A workaround is to construct the model first and then load the weights manually.

with ColoInitContext(
    device=get_current_device(),
    dtype=torch.half,
    default_pg=ProcessGroup(tp_degree=world_size),
    default_dist_spec=ShardSpec([-1], [world_size]),
):
    # Build the (sharded) model from the config only, without from_pretrained().
    model = BloomForCausalLM(BloomConfig.from_pretrained(pretrained_path))
    # state_dict: the full Hugging Face checkpoint loaded beforehand on the CPU
    # (e.g. via torch.load); global_rank: this process's rank, e.g. torch.distributed.get_rank().
    for n, p in model.named_parameters():
        x = state_dict[n]
        # Slice out this rank's shard along the last dim, matching ShardSpec([-1], [world_size]).
        x = x.chunk(world_size, dim=-1)
        x = x[global_rank]
        p.data.copy_(x)
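Note that chunking every parameter along dim=-1 mirrors the default ShardSpec([-1], [world_size]) passed to ColoInitContext, so each rank copies exactly the slice it owns; if you change the shard spec, the slicing in the loop has to change accordingly.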