pytorch-lightning: Missing weight params when loading deepspeed-stage2 ckpts

๐Ÿ› Bug

Converting a deepspeed_zero_stage_2 checkpoint to fp32 misses some parameters. I was fine-tuning bart-large with deepspeed_stage_2, converted the checkpoint to fp32 (using zero_to_fp32.py), and got a .bin file.

When I try to load this model using model.load_state_dict(torch.load(path_to_bin_file)), I get a missing keys error for "model.embed.tokens.weight".

I am using huggingface transformers.

To Reproduce
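
A minimal sketch of the conversion-and-load path described above (the paths, the model name, and the exact error are illustrative assumptions, not a verified reproduction):

# Convert the ZeRO Stage 2 checkpoint directory to a single fp32 file with the
# DeepSpeed-generated script: python zero_to_fp32.py <checkpoint_dir> pytorch_model.bin
import torch
from transformers import BartForConditionalGeneration

model = BartForConditionalGeneration.from_pretrained("facebook/bart-large")
state_dict = torch.load("pytorch_model.bin", map_location="cpu")
# Raises a missing-keys error such as the one reported above, because some
# parameters are absent from the converted state dict.
model.load_state_dict(state_dict)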

Expected behavior

Environment

  • PyTorch Lightning Version: 1.5.0
  • PyTorch Version: 1.10
  • Python version: 3.10
  • OS: Linux
  • CUDA/cuDNN version:
  • GPU models and configuration: V100 (16 GB, single GPU)
  • How you installed PyTorch (conda, pip, source): pip
  • If compiling from source, the output of torch.__config__.show():
  • Any other relevant information:

Additional context

cc @awaelchli @ananthsub @ninginthecloud @rohitgr7 @otaj @SeanNaren @akihironitta

About this issue

  • State: closed
  • Created 3 years ago
  • Reactions: 1
  • Comments: 19 (7 by maintainers)

Most upvoted comments

@TianHongZXY I find that transformers automatically assigns these shared weights to the corresponding parameters, so you can just set strict=False when loading the checkpoint.
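
For example, a sketch of loading the converted .bin with strict=False and inspecting the skipped keys (the file name is an assumption):

import torch
from transformers import BartForConditionalGeneration

model = BartForConditionalGeneration.from_pretrained("facebook/bart-large")
state_dict = torch.load("pytorch_model.bin", map_location="cpu")
# strict=False skips keys that are absent from the converted state dict;
# transformers assigns the shared/tied embedding weights to the matching parameters.
missing, unexpected = model.load_state_dict(state_dict, strict=False)
print("missing:", missing)
print("unexpected:", unexpected)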

For some reason, deepspeed.utils.zero_to_fp32.get_fp32_state_dict_from_zero_checkpoint used in pytorch_lightning.utilities.deepspeed.convert_zero_checkpoint_to_fp32_state_dict returns a subset of the model parameters in the state_dict. I solved it by loading the checkpoint in .../checkpoint/mp_rank_00_model_states.pt (for ZeRO Stage 2) and adding the following to the LightningModule:

def on_load_checkpoint(self, checkpoint):
    # DeepSpeed saves the wrapped module's weights under the 'module' key,
    # with each key carrying a leading 'module.' prefix from the wrapper.
    state_dict = checkpoint['module']
    # Strip the prefix so the keys match the LightningModule's own state dict.
    state_dict = {k[len('module.'):] if k.startswith('module.') else k: v for k, v in state_dict.items()}
    checkpoint['state_dict'] = state_dict
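
With that hook in place, a minimal usage sketch (MyLitModel is a hypothetical LightningModule with no required constructor arguments; the checkpoint path is illustrative):

import torch

# Path inside the ZeRO Stage 2 checkpoint directory written by Lightning.
ckpt_path = "epoch=0-step=1000.ckpt/checkpoint/mp_rank_00_model_states.pt"
checkpoint = torch.load(ckpt_path, map_location="cpu")

lit_model = MyLitModel()
lit_model.on_load_checkpoint(checkpoint)  # fills checkpoint['state_dict'] from checkpoint['module']
lit_model.load_state_dict(checkpoint["state_dict"])

In PL 1.5, LightningModule.load_from_checkpoint should pick the hook up as well, since it calls on_load_checkpoint before load_state_dict.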