pytorch-lightning: Missing weight params when loading deepspeed-stage2 ckpts
🐛 Bug
Converting a deepspeed_stage_2 checkpoint to fp32 misses some parameters. I was fine-tuning bart-large with deepspeed_stage_2, converted the checkpoint to fp32 (using zero_to_fp32.py), and got a .bin file.
When I try to load this model using `model.load_state_dict(path_to_bin_file)` I get a missing-keys error for `model.embed.tokens.weight`.
I am using huggingface transformers.
To Reproduce
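A minimal sketch of the conversion-and-load flow described above, assuming a `transformers` BART model and placeholder file paths (the exact model class and checkpoint locations are not given in the report):

```python
import torch
from transformers import BartForConditionalGeneration

# Placeholder path: the .bin file produced by zero_to_fp32.py from the
# DeepSpeed Stage 2 checkpoint directory.
fp32_bin = "outputs/pytorch_model.bin"

model = BartForConditionalGeneration.from_pretrained("facebook/bart-large")

# load_state_dict expects the state dict itself, so the file is loaded first;
# with the default strict=True this raises the "Missing key(s)" error reported above.
state_dict = torch.load(fp32_bin, map_location="cpu")
model.load_state_dict(state_dict)
```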
Expected behavior
Environment
- PyTorch Lightning Version: 1.5.0
- PyTorch Version: 1.10
- Python version: 3.10
- OS: Linux
- CUDA/cuDNN version:
- GPU models and configuration: V100 (16 GB, single GPU)
- How you installed PyTorch (`conda`, `pip`, source): pip
- If compiling from source, the output of `torch.__config__.show()`:
- Any other relevant information:
Additional context
cc @awaelchli @ananthsub @ninginthecloud @rohitgr7 @otaj @SeanNaren @akihironitta
@TianHongZXY I find that `transformers` would automatically assign these shared weights to the corresponding parameters. You can just set `strict=False` when loading checkpoints.
For some reason, `deepspeed.utils.zero_to_fp32.get_fp32_state_dict_from_zero_checkpoint`, used in `pytorch_lightning.utilities.deepspeed.convert_zero_checkpoint_to_fp32_state_dict`, returns a subset of the model parameters in the state_dict. I solved it by loading the checkpoint in `.../checkpoint/mp_rank_00_model_states.pt` (for ZeRO Stage 2) and adding the following to the `LightningModule`:
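The original snippet is not included in the scraped thread. Below is a minimal sketch of what such a loader could look like, assuming the Stage 2 `mp_rank_00_model_states.pt` file stores the wrapped module's weights under a `module` key and that keys may carry a `module.` prefix; the helper name `load_deepspeed_weights` is hypothetical:

```python
import torch
import pytorch_lightning as pl


class MyLitModule(pl.LightningModule):
    # ... model definition elided ...

    def load_deepspeed_weights(self, ckpt_path):
        # Hypothetical helper: load the per-rank DeepSpeed model-states file,
        # e.g. .../checkpoint/mp_rank_00_model_states.pt for ZeRO Stage 2.
        ckpt = torch.load(ckpt_path, map_location="cpu")

        # Assumption: the full model state dict lives under the "module" key.
        state_dict = ckpt.get("module", ckpt)

        # Assumption: keys may be prefixed with "module." because DeepSpeed
        # wraps the LightningModule; strip the prefix before loading.
        state_dict = {
            (k[len("module."):] if k.startswith("module.") else k): v
            for k, v in state_dict.items()
        }

        # strict=False tolerates the shared/tied weights that transformers
        # re-assigns automatically on load.
        self.load_state_dict(state_dict, strict=False)
```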