transformers: Bug in Trainer: substantially different results when resuming from a checkpoint vs. training without interruption

Environment info

  • transformers version: 4.5.5
  • Platform: linux
  • Python version: 3.7
  • PyTorch version (GPU?): 1.8
  • Tensorflow version (GPU?): -
  • Using GPU in script?: -
  • Using distributed or parallel set-up in script?: -

Who can help

@sgugger @patrickvonplaten, @patil-suraj

Information

  • I am training a T5 model and resuming training from a checkpoint
  • I have already fixed the issue from https://github.com/huggingface/transformers/issues/11294 by freezing the parameters again right after the Trainer loads the model from the checkpoint
  • I am using "evaluation_strategy": "steps", I evaluate the model every 10 steps, and I set "save_total_limit": 1
  • I overrode the _save_checkpoint method as below to save the last copy of the model in output_dir, since one needs to resume from the point where training left off, not from the checkpoint with the best evaluation:
    # Method of a Trainer subclass; it additionally needs os, shutil, warnings,
    # torch, pathlib.Path and transformers.trainer_pt_utils.reissue_pt_warnings in scope.
    def _save_checkpoint(self, model, trial, metrics=None):
        super()._save_checkpoint(model, trial, metrics)
        # Additionally save a copy of the latest checkpoint in the main output folder.
        if self.is_world_process_zero():
            # Remove the older global_step* folders.
            global_steps = [str(x) for x in Path(self.args.output_dir).glob("global_step*")]
            for global_step in global_steps:
                shutil.rmtree(global_step)
            self.save_model(self.args.output_dir)
            if self.deepspeed:
                # deepspeed.save_checkpoint saves model/optimizer/scheduler in one go.
                self.deepspeed.save_checkpoint(self.args.output_dir)
            else:
                torch.save(self.optimizer.state_dict(), os.path.join(self.args.output_dir, "optimizer.pt"))
                with warnings.catch_warnings(record=True) as caught_warnings:
                    torch.save(self.lr_scheduler.state_dict(), os.path.join(self.args.output_dir, "scheduler.pt"))
                reissue_pt_warnings(caught_warnings)
            self.state.save_to_json(os.path.join(self.args.output_dir, "trainer_state.json"))

Then I find the last checkpoint to resume from, using the copy saved in the output directory, as below:

def get_last_checkpoint(output_dir):
    if os.path.exists(os.path.join(output_dir, 'pytorch_model.bin')):
        return output_dir
    return None
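For reference, the helper can be exercised on its own; the returned path is what would then be passed to trainer.train(resume_from_checkpoint=...). A minimal self-contained sketch, assuming the same file layout as above:

```python
import os
import tempfile

def get_last_checkpoint(output_dir):
    # Resume from output_dir itself when a full copy of the model was saved there.
    if os.path.exists(os.path.join(output_dir, "pytorch_model.bin")):
        return output_dir
    return None

# Quick self-check: empty directory vs. one containing a saved model file.
with tempfile.TemporaryDirectory() as output_dir:
    assert get_last_checkpoint(output_dir) is None
    open(os.path.join(output_dir, "pytorch_model.bin"), "wb").close()
    last_checkpoint = get_last_checkpoint(output_dir)
    assert last_checkpoint == output_dir
    # last_checkpoint is then given to trainer.train(resume_from_checkpoint=last_checkpoint)
```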

Here are the results without resuming, for 10 evaluations:

{'loss': 5.0483, 'learning_rate': 6e-07, 'epoch': 0.02}
(each evaluation: Num examples = 204, Batch size = 80)
step  10/60000: {'mrpc_en_eval_loss': 5.382528305053711, 'mrpc_en_eval_f1': 0.0, 'mrpc_en_eval_accuracy': 0.0, 'mrpc_en_eval_gen_len': 3.0, 'mrpc_en_eval_runtime': 1.8421, 'mrpc_en_eval_samples_per_second': 110.741, 'epoch': 0.22, 'eval_average_metrics': 0.0}
step  20/60000: {'mrpc_en_eval_loss': 5.180729389190674, 'mrpc_en_eval_f1': 0.0, 'mrpc_en_eval_accuracy': 0.0, 'mrpc_en_eval_gen_len': 3.0, 'mrpc_en_eval_runtime': 1.8179, 'mrpc_en_eval_samples_per_second': 112.218, 'epoch': 0.43, 'eval_average_metrics': 0.0}
step  30/60000: {'mrpc_en_eval_loss': 4.810805320739746, 'mrpc_en_eval_f1': 0.0, 'mrpc_en_eval_accuracy': 0.0, 'mrpc_en_eval_gen_len': 3.0, 'mrpc_en_eval_runtime': 1.8421, 'mrpc_en_eval_samples_per_second': 110.743, 'epoch': 0.65, 'eval_average_metrics': 0.0}
step  40/60000: {'mrpc_en_eval_loss': 4.203256607055664, 'mrpc_en_eval_f1': 0.0, 'mrpc_en_eval_accuracy': 0.0, 'mrpc_en_eval_gen_len': 3.0098039215686274, 'mrpc_en_eval_runtime': 2.031, 'mrpc_en_eval_samples_per_second': 100.441, 'epoch': 0.87, 'eval_average_metrics': 0.0}
step  50/60000: {'mrpc_en_eval_loss': 3.262455463409424, 'mrpc_en_eval_f1': 0.0, 'mrpc_en_eval_accuracy': 0.0, 'mrpc_en_eval_gen_len': 3.0098039215686274, 'mrpc_en_eval_runtime': 2.1069, 'mrpc_en_eval_samples_per_second': 96.825, 'epoch': 1.09, 'eval_average_metrics': 0.0}
step  60/60000: {'mrpc_en_eval_loss': 1.9655567407608032, 'mrpc_en_eval_f1': 0.0, 'mrpc_en_eval_accuracy': 0.49019607843137253, 'mrpc_en_eval_gen_len': 3.053921568627451, 'mrpc_en_eval_runtime': 2.8657, 'mrpc_en_eval_samples_per_second': 71.186, 'epoch': 1.3, 'eval_average_metrics': 0.24509803921568626}
step  70/60000: {'mrpc_en_eval_loss': 0.7519775032997131, 'mrpc_en_eval_f1': 18.404907975460123, 'mrpc_en_eval_accuracy': 34.80392156862745, 'mrpc_en_eval_gen_len': 2.9411764705882355, 'mrpc_en_eval_runtime': 2.6193, 'mrpc_en_eval_samples_per_second': 77.884, 'epoch': 1.52, 'eval_average_metrics': 26.60441477204379}
step  80/60000: {'mrpc_en_eval_loss': 0.4142318665981293, 'mrpc_en_eval_f1': 75.62500000000001, 'mrpc_en_eval_accuracy': 61.76470588235294, 'mrpc_en_eval_gen_len': 2.1176470588235294, 'mrpc_en_eval_runtime': 1.7878, 'mrpc_en_eval_samples_per_second': 114.109, 'epoch': 1.74, 'eval_average_metrics': 68.69485294117648}
step  90/60000: {'mrpc_en_eval_loss': 0.3786551058292389, 'mrpc_en_eval_f1': 51.18483412322274, 'mrpc_en_eval_accuracy': 49.50980392156863, 'mrpc_en_eval_gen_len': 2.6519607843137254, 'mrpc_en_eval_runtime': 1.8265, 'mrpc_en_eval_samples_per_second': 111.69, 'epoch': 1.96, 'eval_average_metrics': 50.34731902239569}
step 100/60000: {'mrpc_en_eval_loss': 0.29472649097442627, 'mrpc_en_eval_f1': 71.01449275362319, 'mrpc_en_eval_accuracy': 60.78431372549019, 'mrpc_en_eval_gen_len': 2.3333333333333335, 'mrpc_en_eval_runtime': 1.812, 'mrpc_en_eval_samples_per_second': 112.581, 'epoch': 2.17, 'eval_average_metrics': 65.89940323955669}
  

Now let's resume from step 40. The first 40 steps give essentially the same results, but after resuming the results differ a lot:

(each evaluation: Num examples = 204, Batch size = 80)
step  40/60000: {'mrpc_en_eval_loss': 4.203643321990967, 'mrpc_en_eval_f1': 0.0, 'mrpc_en_eval_accuracy': 0.0, 'mrpc_en_eval_gen_len': 3.0098039215686274, 'mrpc_en_eval_runtime': 2.0033, 'mrpc_en_eval_samples_per_second': 101.834, 'epoch': 0.87, 'eval_average_metrics': 0.0}
step  50/60000: {'mrpc_en_eval_loss': 3.2706634998321533, 'mrpc_en_eval_f1': 0.0, 'mrpc_en_eval_accuracy': 0.0, 'mrpc_en_eval_gen_len': 3.0098039215686274, 'mrpc_en_eval_runtime': 2.2048, 'mrpc_en_eval_samples_per_second': 92.524, 'epoch': 1.09, 'eval_average_metrics': 0.0}
step  60/60000: {'mrpc_en_eval_loss': 1.9863247871398926, 'mrpc_en_eval_f1': 0.0, 'mrpc_en_eval_accuracy': 0.49019607843137253, 'mrpc_en_eval_gen_len': 3.019607843137255, 'mrpc_en_eval_runtime': 2.4126, 'mrpc_en_eval_samples_per_second': 84.557, 'epoch': 1.3, 'eval_average_metrics': 0.24509803921568626}
step  70/60000: {'mrpc_en_eval_loss': 0.7721647620201111, 'mrpc_en_eval_f1': 18.404907975460123, 'mrpc_en_eval_accuracy': 34.80392156862745, 'mrpc_en_eval_gen_len': 2.946078431372549, 'mrpc_en_eval_runtime': 2.5655, 'mrpc_en_eval_samples_per_second': 79.518, 'epoch': 1.52, 'eval_average_metrics': 26.60441477204379}
step  80/60000: {'mrpc_en_eval_loss': 0.42692506313323975, 'mrpc_en_eval_f1': 74.28571428571428, 'mrpc_en_eval_accuracy': 60.29411764705882, 'mrpc_en_eval_gen_len': 2.142156862745098, 'mrpc_en_eval_runtime': 1.8243, 'mrpc_en_eval_samples_per_second': 111.824, 'epoch': 1.74, 'eval_average_metrics': 67.28991596638654}
step  90/60000: {'mrpc_en_eval_loss': 0.39015302062034607, 'mrpc_en_eval_f1': 45.685279187817265, 'mrpc_en_eval_accuracy': 47.549019607843135, 'mrpc_en_eval_gen_len': 2.7205882352941178, 'mrpc_en_eval_runtime': 1.856, 'mrpc_en_eval_samples_per_second': 109.915, 'epoch': 1.96, 'eval_average_metrics': 46.617149397830204}
step 100/60000: {'mrpc_en_eval_loss': 0.30966323614120483, 'mrpc_en_eval_f1': 68.48249027237354, 'mrpc_en_eval_accuracy': 60.29411764705882, 'mrpc_en_eval_gen_len': 2.426470588235294, 'mrpc_en_eval_runtime': 1.8275, 'mrpc_en_eval_samples_per_second': 111.625, 'epoch': 2.17, 'eval_average_metrics': 64.38830395971618}

Expected behavior

Resuming from a checkpoint should produce the same results as training without interruption.

Thank you for your help @sgugger

About this issue

  • Original URL
  • State: closed
  • Created 3 years ago
  • Comments: 24 (6 by maintainers)

Most upvoted comments

This might be due to the FP16 parameter. Could you check whether you get the same result without FP16? The reason is that in mixed-precision training we don't save the state of the gradient scaler, which is another piece of state that needs restoring. I can make a PR to fix that tomorrow.
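For context, restoring the scaler amounts to round-tripping its state_dict alongside the other checkpoint files. A sketch, assuming the file name "scaler.pt" (my choice for illustration, not necessarily what the PR uses):

```python
import os
import tempfile

import torch

# Sketch: persist and restore the AMP gradient scaler state so that the loss
# scale and growth tracker resume where they left off.
def save_scaler_state(scaler, checkpoint_dir):
    torch.save(scaler.state_dict(), os.path.join(checkpoint_dir, "scaler.pt"))

def load_scaler_state(scaler, checkpoint_dir):
    path = os.path.join(checkpoint_dir, "scaler.pt")
    if os.path.isfile(path):
        scaler.load_state_dict(torch.load(path))

# Usage: scaler is the torch.cuda.amp.GradScaler driving fp16 training.
scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())
with tempfile.TemporaryDirectory() as checkpoint_dir:
    save_scaler_state(scaler, checkpoint_dir)   # at checkpoint time
    load_scaler_state(scaler, checkpoint_dir)   # at resume time
```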

Hi, I cannot really express how much I appreciate this. Thank you both very much for working on it; it would be wonderful to have resuming fixed in the Trainer. Thanks for your efforts.

Thank you for your detailed follow-up, @dorooddorood606, and for sharing the experiments you have tried.

I agree that it’d be awesome to be able to resume as if there was no stopping.

Please give us some time; we are going to discuss whether it is feasible to make this happen, as there are many moving parts to consider, and if so we will build it from the ground up.

We will keep you posted.

Dear @stas00

I appreciate your input on the issue of reproducibility from resuming from checkpoints a lot. I tried to follow your points to state it in a clearer way.

Problem statement

If a user trains a model for some steps and then reloads it from a checkpoint, the results differ from training the model without interruption.

How to reproduce the issue

Transformers version: I am using a 4.6.0 dev version of transformers, at this commit:

https://github.com/huggingface/transformers/commit/04ab2ca639ee6fd1002ce0a498d892245f0c9093

Please clone this repository, which contains a minimal example:

git clone git@github.com:dorooddorood606/reproducibility.git

To reproduce, run the command below. Kill the process 2-3 times, each time shortly after a checkpoint has been saved (every 50 steps), and restart it so that it resumes. Then compare the final results of the interrupted-and-resumed run with a run trained for the full number of iterations without any breaks: the results differ.

TASK_NAME=mrpc
python run_glue.py   --model_name_or_path bert-base-cased   --task_name $TASK_NAME   --do_train   --do_eval   --max_seq_length 128   --per_device_train_batch_size 32   --learning_rate 2e-5   --num_train_epochs 2   --output_dir /temp/$TASK_NAME/  --eval_steps 50 --evaluation_strategy steps --load_best_model_at_end --fp16 --do_predict

Please let me know if you need any further information on this.

Modifications made to the Trainer class to make resuming reproducible:

I applied the following modifications:

  1. Following your suggestions, I save the random states and reload them before loading the checkpoint in the Trainer class. Please see https://github.com/dorooddorood606/reproducibility/blob/f5902af4669bba8aaee326efdb0cd459e25be675/trainer.py#L126 and https://github.com/dorooddorood606/reproducibility/blob/f5902af4669bba8aaee326efdb0cd459e25be675/trainer.py#L200

  2. On each checkpoint save, I also save a copy of the checkpoint in the output_dir. This is because I believe we need to keep the last checkpoint to resume from, in addition to the checkpoint of the best model so far, in order to continue training from the last state. Please see https://github.com/dorooddorood606/reproducibility/blob/f5902af4669bba8aaee326efdb0cd459e25be675/trainer.py#L87

  3. I locate the last checkpoint in run_glue.py based on the checkpoint saved in the main output_dir; please see https://github.com/dorooddorood606/reproducibility/blob/f5902af4669bba8aaee326efdb0cd459e25be675/run_glue.py#L46
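The random-state handling in modification 1 can be sketched as follows (a minimal sketch of the idea, not the exact code in the linked repo):

```python
import random

import numpy as np
import torch

# Capture every RNG the Trainer touches, so that after resuming, data shuffling
# and dropout masks replay exactly as in an uninterrupted run.
def get_rng_states():
    states = {
        "python": random.getstate(),
        "numpy": np.random.get_state(),
        "torch": torch.get_rng_state(),
    }
    if torch.cuda.is_available():
        states["cuda"] = torch.cuda.get_rng_state_all()
    return states

def set_rng_states(states):
    random.setstate(states["python"])
    np.random.set_state(states["numpy"])
    torch.set_rng_state(states["torch"])
    if torch.cuda.is_available() and "cuda" in states:
        torch.cuda.set_rng_state_all(states["cuda"])

# Restoring the states replays the same random draws:
saved = get_rng_states()
draws = (random.random(), torch.rand(2))
set_rng_states(saved)
assert random.random() == draws[0]
assert torch.equal(torch.rand(2), draws[1])
```

The states dict would be saved with torch.save at checkpoint time and restored just before training resumes.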

Larger impact of this issue

This issue with resuming from a checkpoint affects everyone who relies on this option, so fixing it would benefit all such users. I would appreciate it a lot if you could spare some of your time to help with this issue.

Dear @stas00, thank you for the reminder; I will follow the points you mentioned. I thought there was also a bug in the Trainer itself, since I observed the issue with an unmodified BERT-base model as well, but that randomness issue was resolved by upgrading to version 4.6.0 of transformers.