transformers: Trainer/accelerate crashes when loading checkpoint using FSDP: sync_module_states ValueError

System Info

  • transformers version: 4.31.0
  • Platform: Linux-5.15.0-41-generic-x86_64-with-glibc2.31
  • Python version: 3.11.4
  • Huggingface_hub version: 0.16.4
  • Safetensors version: 0.3.1
  • Accelerate version: 0.21.0
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.0.1+cu117 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?: Yes
  • Using distributed or parallel set-up in script?: Yes

Who can help?

@sgugger @pacman100

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, …)
  • My own task or dataset (give details below)

Reproduction

  1. Use run_clm.py (https://github.com/huggingface/transformers/blob/main/examples/pytorch/language-modeling/run_clm.py) to train a large model with the HuggingFace Trainer under FSDP, saving checkpoints along the way. For example:
torchrun --nproc_per_node=4 --master_port=XXXXX experiments/run_clm.py \
    --model_name_or_path meta-llama/Llama-2-7b-hf \
    --dataset_name openwebtext \
    --streaming \
    --per_device_train_batch_size 16 \
    --gradient_accumulation_steps 1 \
    --do_train \
    --max_steps 1000 \
    --output_dir output_dir/ \
    --block_size 512 \
    --save_steps 10 \
    --save_total_limit 1 \
    --fsdp "full_shard auto_wrap" \
    --fsdp_transformer_layer_cls_to_wrap "LlamaDecoderLayer" \
    --tf32 True \
    --bf16 True \
    --gradient_checkpointing
  2. Kill training after a checkpoint has been saved, then resume training from that checkpoint with the resume_from_checkpoint training argument (a sample resume invocation is shown after the traceback below).
  3. Observed behavior: training crashes while loading the checkpointed model; the same traceback is raised on every rank:
Traceback (most recent call last):
  File ".../run_clm.py", line 638, in <module>
    main()
  File ".../run_clm.py", line 584, in main
    train_result = trainer.train(resume_from_checkpoint=checkpoint)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "miniconda3.../lib/python3.11/site-packages/transformers/trainer.py", line 1528, in train
    self._load_from_checkpoint(resume_from_checkpoint)
  File "miniconda3.../lib/python3.11/site-packages/transformers/trainer.py", line 2055, in _load_from_checkpoint
    load_fsdp_model(self.accelerator.state.fsdp_plugin, self.accelerator, model, resume_from_checkpoint)
  File "miniconda3.../lib/python3.11/site-packages/accelerate/utils/fsdp_utils.py", line 79, in load_fsdp_model
    raise ValueError(
ValueError: Set the `sync_module_states` flag to `True` so that model states are synced across processes when initializing FSDP object
WARNING:torch.distributed.elastic.multiprocessing.api:Sending process 45997 closing signal SIGTERM                                                                            
ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: 1) local_rank: 1 (pid: 45998) of binary: miniconda3.../bin/python
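
For step 2, a resume invocation of this kind is what I mean (sketch only: the checkpoint directory name is just an example, use whichever output_dir/checkpoint-* the Trainer last saved):

torchrun --nproc_per_node=4 --master_port=XXXXX experiments/run_clm.py \
    <same arguments as above> \
    --resume_from_checkpoint output_dir/checkpoint-10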

Expected behavior

Training should resume from the saved checkpoint when using FSDP.
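
For reference, the flag named in the error corresponds to accelerate's FullyShardedDataParallelPlugin.sync_module_states. A minimal sketch of setting it explicitly when building an Accelerator (assuming the accelerate 0.21 API; the Trainer normally constructs this plugin internally from the --fsdp arguments, so this illustrates the flag rather than the exact run_clm.py code path):

    from accelerate import Accelerator
    from accelerate.utils import FullyShardedDataParallelPlugin

    # Sketch only: ask FSDP to broadcast rank 0's module states to the other
    # ranks when wrapping, which is the condition load_fsdp_model checks for.
    fsdp_plugin = FullyShardedDataParallelPlugin(sync_module_states=True)
    accelerator = Accelerator(fsdp_plugin=fsdp_plugin)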

About this issue

  • State: closed
  • Created a year ago
  • Comments: 27 (12 by maintainers)

Most upvoted comments

@pacman100 I’m also running into a similar error with the latest main branch:

File "/home/hyen/.conda/envs/cross/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 608, in set_state_dict_type
    state_dict_config_type = _state_dict_type_to_config[state_dict_type]
KeyError: None
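
For what it's worth, that KeyError comes from FSDP.set_state_dict_type being called with state_dict_type=None, i.e. the plugin's state dict type was never resolved. A tiny illustration of the failing lookup (the mapping below is a simplified stand-in for torch's internal _state_dict_type_to_config, not the real table):

    from torch.distributed.fsdp import FullStateDictConfig, StateDictType

    # Simplified stand-in; the real mapping also covers LOCAL_STATE_DICT and
    # SHARDED_STATE_DICT.
    _state_dict_type_to_config = {StateDictType.FULL_STATE_DICT: FullStateDictConfig}

    state_dict_type = None  # what the plugin appears to pass in the failing case
    config_cls = _state_dict_type_to_config[state_dict_type]  # raises KeyError: None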