transformers: Bug in Trainer: substantially different results when restarting from a checkpoint versus training without interruption
Environment info
- transformers version: 4.5.5
- Platform: linux
- Python version: 3.7
- PyTorch version (GPU?): 1.8
- Tensorflow version (GPU?): -
- Using GPU in script?: -
- Using distributed or parallel set-up in script?: -
Who can help
@sgugger @patrickvonplaten, @patil-suraj
Information
- I am training a T5 model and resuming training from a checkpoint
- I worked around the issue reported at https://github.com/huggingface/transformers/issues/11294 by freezing the parameters again right after the Trainer loads the model from the checkpoint
- I am using "evaluation_strategy": "steps", evaluating the model every 10 steps, with "save_total_limit": 1
- I overrode the _save_checkpoint method as below so that it also saves the latest copy of the model in output_dir, because training should resume from where it stopped, not from the checkpoint with the best evaluation:
```python
import os
import shutil
import warnings
from pathlib import Path

import torch
from transformers import Trainer
from transformers.trainer_pt_utils import reissue_pt_warnings


class ResumableTrainer(Trainer):  # hypothetical subclass name
    def _save_checkpoint(self, model, trial, metrics=None):
        super()._save_checkpoint(model, trial, metrics)
        # Also save the latest model checkpoint in the main output folder.
        if self.is_world_process_zero():
            # Remove the older DeepSpeed global_step* folders.
            global_steps = [str(x) for x in Path(self.args.output_dir).glob("global_step*")]
            for global_step in global_steps:
                shutil.rmtree(global_step)
            self.save_model(self.args.output_dir)
            if self.deepspeed:
                self.deepspeed.save_checkpoint(self.args.output_dir)
            else:
                # deepspeed.save_checkpoint above saves model/optim/sched
                torch.save(self.optimizer.state_dict(), os.path.join(self.args.output_dir, "optimizer.pt"))
                with warnings.catch_warnings(record=True) as caught_warnings:
                    torch.save(self.lr_scheduler.state_dict(), os.path.join(self.args.output_dir, "scheduler.pt"))
                reissue_pt_warnings(caught_warnings)
            self.state.save_to_json(os.path.join(self.args.output_dir, "trainer_state.json"))
```
Then I pick the checkpoint to resume from, namely the copy saved in the output directory, as below:
```python
import os

def get_last_checkpoint(output_dir):
    # Resume from output_dir itself if the latest model copy was saved there.
    if os.path.exists(os.path.join(output_dir, "pytorch_model.bin")):
        return output_dir
    return None
```
Here are the results of training without resuming, over 10 evaluations:
{'loss': 5.0483, 'learning_rate': 6e-07, 'epoch': 0.02}
0%| | 10/60000 [00:07<11:11:04, 1.49it/s]***** Running Evaluation *****
Num examples = 204
Batch size = 80
### n_samples 204███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:01<00:00, 2.54it/s]
{'mrpc_en_eval_loss': 5.382528305053711, 'mrpc_en_eval_f1': 0.0, 'mrpc_en_eval_accuracy': 0.0, 'mrpc_en_eval_gen_len': 3.0, 'mrpc_en_eval_runtime': 1.8421, 'mrpc_en_eval_samples_per_second': 110.741, 'epoch': 0.22}
{'mrpc_en_eval_loss': 5.382528305053711, 'mrpc_en_eval_f1': 0.0, 'mrpc_en_eval_accuracy': 0.0, 'mrpc_en_eval_gen_len': 3.0, 'mrpc_en_eval_runtime': 1.8421, 'mrpc_en_eval_samples_per_second': 110.741, 'epoch': 0.22, 'eval_average_metrics': 0.0}
0%| | 20/60000 [00:20<11:57:29, 1.39it/s]***** Running Evaluation *****
Num examples = 204
Batch size = 80
### n_samples 204███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:01<00:00, 2.56it/s]
{'mrpc_en_eval_loss': 5.180729389190674, 'mrpc_en_eval_f1': 0.0, 'mrpc_en_eval_accuracy': 0.0, 'mrpc_en_eval_gen_len': 3.0, 'mrpc_en_eval_runtime': 1.8179, 'mrpc_en_eval_samples_per_second': 112.218, 'epoch': 0.43}
{'mrpc_en_eval_loss': 5.180729389190674, 'mrpc_en_eval_f1': 0.0, 'mrpc_en_eval_accuracy': 0.0, 'mrpc_en_eval_gen_len': 3.0, 'mrpc_en_eval_runtime': 1.8179, 'mrpc_en_eval_samples_per_second': 112.218, 'epoch': 0.43, 'eval_average_metrics': 0.0}
0%| | 30/60000 [00:33<12:01:13, 1.39it/s]***** Running Evaluation *****
Num examples = 204
Batch size = 80
### n_samples 204███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:01<00:00, 2.52it/s]
{'mrpc_en_eval_loss': 4.810805320739746, 'mrpc_en_eval_f1': 0.0, 'mrpc_en_eval_accuracy': 0.0, 'mrpc_en_eval_gen_len': 3.0, 'mrpc_en_eval_runtime': 1.8421, 'mrpc_en_eval_samples_per_second': 110.743, 'epoch': 0.65}
{'mrpc_en_eval_loss': 4.810805320739746, 'mrpc_en_eval_f1': 0.0, 'mrpc_en_eval_accuracy': 0.0, 'mrpc_en_eval_gen_len': 3.0, 'mrpc_en_eval_runtime': 1.8421, 'mrpc_en_eval_samples_per_second': 110.743, 'epoch': 0.65, 'eval_average_metrics': 0.0}
0%| | 40/60000 [00:45<11:17:50, 1.47it/s]***** Running Evaluation *****
Num examples = 204
Batch size = 80
### n_samples 204███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:01<00:00, 2.54it/s]
{'mrpc_en_eval_loss': 4.203256607055664, 'mrpc_en_eval_f1': 0.0, 'mrpc_en_eval_accuracy': 0.0, 'mrpc_en_eval_gen_len': 3.0098039215686274, 'mrpc_en_eval_runtime': 2.031, 'mrpc_en_eval_samples_per_second': 100.441, 'epoch': 0.87}
{'mrpc_en_eval_loss': 4.203256607055664, 'mrpc_en_eval_f1': 0.0, 'mrpc_en_eval_accuracy': 0.0, 'mrpc_en_eval_gen_len': 3.0098039215686274, 'mrpc_en_eval_runtime': 2.031, 'mrpc_en_eval_samples_per_second': 100.441, 'epoch': 0.87, 'eval_average_metrics': 0.0}
0%| | 50/60000 [00:58<11:42:57, 1.42it/s]***** Running Evaluation *****
Num examples = 204
Batch size = 80
### n_samples 204███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:01<00:00, 2.39it/s]
{'mrpc_en_eval_loss': 3.262455463409424, 'mrpc_en_eval_f1': 0.0, 'mrpc_en_eval_accuracy': 0.0, 'mrpc_en_eval_gen_len': 3.0098039215686274, 'mrpc_en_eval_runtime': 2.1069, 'mrpc_en_eval_samples_per_second': 96.825, 'epoch': 1.09}
{'mrpc_en_eval_loss': 3.262455463409424, 'mrpc_en_eval_f1': 0.0, 'mrpc_en_eval_accuracy': 0.0, 'mrpc_en_eval_gen_len': 3.0098039215686274, 'mrpc_en_eval_runtime': 2.1069, 'mrpc_en_eval_samples_per_second': 96.825, 'epoch': 1.09, 'eval_average_metrics': 0.0}
0%|▏ | 60/60000 [01:13<11:57:15, 1.39it/s]***** Running Evaluation *****
Num examples = 204
Batch size = 80
### n_samples 204███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:01<00:00, 1.78it/s]
{'mrpc_en_eval_loss': 1.9655567407608032, 'mrpc_en_eval_f1': 0.0, 'mrpc_en_eval_accuracy': 0.49019607843137253, 'mrpc_en_eval_gen_len': 3.053921568627451, 'mrpc_en_eval_runtime': 2.8657, 'mrpc_en_eval_samples_per_second': 71.186, 'epoch': 1.3}
{'mrpc_en_eval_loss': 1.9655567407608032, 'mrpc_en_eval_f1': 0.0, 'mrpc_en_eval_accuracy': 0.49019607843137253, 'mrpc_en_eval_gen_len': 3.053921568627451, 'mrpc_en_eval_runtime': 2.8657, 'mrpc_en_eval_samples_per_second': 71.186, 'epoch': 1.3, 'eval_average_metrics': 0.24509803921568626}
0%|▏ | 70/60000 [01:27<12:14:11, 1.36it/s]***** Running Evaluation *****
Num examples = 204
Batch size = 80
### n_samples 204███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:01<00:00, 2.08it/s]
{'mrpc_en_eval_loss': 0.7519775032997131, 'mrpc_en_eval_f1': 18.404907975460123, 'mrpc_en_eval_accuracy': 34.80392156862745, 'mrpc_en_eval_gen_len': 2.9411764705882355, 'mrpc_en_eval_runtime': 2.6193, 'mrpc_en_eval_samples_per_second': 77.884, 'epoch': 1.52}
{'mrpc_en_eval_loss': 0.7519775032997131, 'mrpc_en_eval_f1': 18.404907975460123, 'mrpc_en_eval_accuracy': 34.80392156862745, 'mrpc_en_eval_gen_len': 2.9411764705882355, 'mrpc_en_eval_runtime': 2.6193, 'mrpc_en_eval_samples_per_second': 77.884, 'epoch': 1.52, 'eval_average_metrics': 26.60441477204379}
0%|▏ | 80/60000 [01:41<12:02:22, 1.38it/s]***** Running Evaluation *****
Num examples = 204
Batch size = 80
### n_samples 204███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:01<00:00, 2.60it/s]
{'mrpc_en_eval_loss': 0.4142318665981293, 'mrpc_en_eval_f1': 75.62500000000001, 'mrpc_en_eval_accuracy': 61.76470588235294, 'mrpc_en_eval_gen_len': 2.1176470588235294, 'mrpc_en_eval_runtime': 1.7878, 'mrpc_en_eval_samples_per_second': 114.109, 'epoch': 1.74}
{'mrpc_en_eval_loss': 0.4142318665981293, 'mrpc_en_eval_f1': 75.62500000000001, 'mrpc_en_eval_accuracy': 61.76470588235294, 'mrpc_en_eval_gen_len': 2.1176470588235294, 'mrpc_en_eval_runtime': 1.7878, 'mrpc_en_eval_samples_per_second': 114.109, 'epoch': 1.74, 'eval_average_metrics': 68.69485294117648}
0%|▏ | 90/60000 [01:54<11:41:23, 1.42it/s]***** Running Evaluation *****
Num examples = 204
Batch size = 80
### n_samples 204███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:01<00:00, 2.54it/s]
{'mrpc_en_eval_loss': 0.3786551058292389, 'mrpc_en_eval_f1': 51.18483412322274, 'mrpc_en_eval_accuracy': 49.50980392156863, 'mrpc_en_eval_gen_len': 2.6519607843137254, 'mrpc_en_eval_runtime': 1.8265, 'mrpc_en_eval_samples_per_second': 111.69, 'epoch': 1.96}
{'mrpc_en_eval_loss': 0.3786551058292389, 'mrpc_en_eval_f1': 51.18483412322274, 'mrpc_en_eval_accuracy': 49.50980392156863, 'mrpc_en_eval_gen_len': 2.6519607843137254, 'mrpc_en_eval_runtime': 1.8265, 'mrpc_en_eval_samples_per_second': 111.69, 'epoch': 1.96, 'eval_average_metrics': 50.34731902239569}
0%|▏ | 100/60000 [02:07<12:01:27, 1.38it/s]***** Running Evaluation *****
Num examples = 204
Batch size = 80
### n_samples 204███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:01<00:00, 2.58it/s]
{'mrpc_en_eval_loss': 0.29472649097442627, 'mrpc_en_eval_f1': 71.01449275362319, 'mrpc_en_eval_accuracy': 60.78431372549019, 'mrpc_en_eval_gen_len': 2.3333333333333335, 'mrpc_en_eval_runtime': 1.812, 'mrpc_en_eval_samples_per_second': 112.581, 'epoch': 2.17}
{'mrpc_en_eval_loss': 0.29472649097442627, 'mrpc_en_eval_f1': 71.01449275362319, 'mrpc_en_eval_accuracy': 60.78431372549019, 'mrpc_en_eval_gen_len': 2.3333333333333335, 'mrpc_en_eval_runtime': 1.812, 'mrpc_en_eval_samples_per_second': 112.581, 'epoch': 2.17, 'eval_average_metrics': 65.89940323955669}
Now let's resume from step 40. The first 40 steps give the same results, but after resuming the results differ a lot:
0%| | 40/60000 [00:07<9:49:41, 1.69it/s]***** Running Evaluation *****
Num examples = 204
Batch size = 80
### n_samples 204███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:01<00:00, 2.62it/s]
{'mrpc_en_eval_loss': 4.203643321990967, 'mrpc_en_eval_f1': 0.0, 'mrpc_en_eval_accuracy': 0.0, 'mrpc_en_eval_gen_len': 3.0098039215686274, 'mrpc_en_eval_runtime': 2.0033, 'mrpc_en_eval_samples_per_second': 101.834, 'epoch': 0.87}
{'mrpc_en_eval_loss': 4.203643321990967, 'mrpc_en_eval_f1': 0.0, 'mrpc_en_eval_accuracy': 0.0, 'mrpc_en_eval_gen_len': 3.0098039215686274, 'mrpc_en_eval_runtime': 2.0033, 'mrpc_en_eval_samples_per_second': 101.834, 'epoch': 0.87, 'eval_average_metrics': 0.0}
0%| | 50/60000 [00:21<12:09:50, 1.37it/s]***** Running Evaluation *****
Num examples = 204
Batch size = 80
### n_samples 204███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:01<00:00, 2.30it/s]
{'mrpc_en_eval_loss': 3.2706634998321533, 'mrpc_en_eval_f1': 0.0, 'mrpc_en_eval_accuracy': 0.0, 'mrpc_en_eval_gen_len': 3.0098039215686274, 'mrpc_en_eval_runtime': 2.2048, 'mrpc_en_eval_samples_per_second': 92.524, 'epoch': 1.09}
{'mrpc_en_eval_loss': 3.2706634998321533, 'mrpc_en_eval_f1': 0.0, 'mrpc_en_eval_accuracy': 0.0, 'mrpc_en_eval_gen_len': 3.0098039215686274, 'mrpc_en_eval_runtime': 2.2048, 'mrpc_en_eval_samples_per_second': 92.524, 'epoch': 1.09, 'eval_average_metrics': 0.0}
0%|▏ | 60/60000 [00:35<12:27:28, 1.34it/s]***** Running Evaluation *****
Num examples = 204
Batch size = 80
### n_samples 204███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:01<00:00, 2.54it/s]
{'mrpc_en_eval_loss': 1.9863247871398926, 'mrpc_en_eval_f1': 0.0, 'mrpc_en_eval_accuracy': 0.49019607843137253, 'mrpc_en_eval_gen_len': 3.019607843137255, 'mrpc_en_eval_runtime': 2.4126, 'mrpc_en_eval_samples_per_second': 84.557, 'epoch': 1.3}
{'mrpc_en_eval_loss': 1.9863247871398926, 'mrpc_en_eval_f1': 0.0, 'mrpc_en_eval_accuracy': 0.49019607843137253, 'mrpc_en_eval_gen_len': 3.019607843137255, 'mrpc_en_eval_runtime': 2.4126, 'mrpc_en_eval_samples_per_second': 84.557, 'epoch': 1.3, 'eval_average_metrics': 0.24509803921568626}
0%|▏ | 70/60000 [00:49<12:02:36, 1.38it/s]***** Running Evaluation *****
Num examples = 204
Batch size = 80
### n_samples 204███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:01<00:00, 2.07it/s]
{'mrpc_en_eval_loss': 0.7721647620201111, 'mrpc_en_eval_f1': 18.404907975460123, 'mrpc_en_eval_accuracy': 34.80392156862745, 'mrpc_en_eval_gen_len': 2.946078431372549, 'mrpc_en_eval_runtime': 2.5655, 'mrpc_en_eval_samples_per_second': 79.518, 'epoch': 1.52}
{'mrpc_en_eval_loss': 0.7721647620201111, 'mrpc_en_eval_f1': 18.404907975460123, 'mrpc_en_eval_accuracy': 34.80392156862745, 'mrpc_en_eval_gen_len': 2.946078431372549, 'mrpc_en_eval_runtime': 2.5655, 'mrpc_en_eval_samples_per_second': 79.518, 'epoch': 1.52, 'eval_average_metrics': 26.60441477204379}
0%|▏ | 80/60000 [01:02<12:08:06, 1.37it/s]***** Running Evaluation *****
Num examples = 204
Batch size = 80
### n_samples 204███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:01<00:00, 2.55it/s]
{'mrpc_en_eval_loss': 0.42692506313323975, 'mrpc_en_eval_f1': 74.28571428571428, 'mrpc_en_eval_accuracy': 60.29411764705882, 'mrpc_en_eval_gen_len': 2.142156862745098, 'mrpc_en_eval_runtime': 1.8243, 'mrpc_en_eval_samples_per_second': 111.824, 'epoch': 1.74}
{'mrpc_en_eval_loss': 0.42692506313323975, 'mrpc_en_eval_f1': 74.28571428571428, 'mrpc_en_eval_accuracy': 60.29411764705882, 'mrpc_en_eval_gen_len': 2.142156862745098, 'mrpc_en_eval_runtime': 1.8243, 'mrpc_en_eval_samples_per_second': 111.824, 'epoch': 1.74, 'eval_average_metrics': 67.28991596638654}
0%|▏ | 90/60000 [01:16<12:00:53, 1.39it/s]***** Running Evaluation *****
Num examples = 204
Batch size = 80
### n_samples 204███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:01<00:00, 2.50it/s]
{'mrpc_en_eval_loss': 0.39015302062034607, 'mrpc_en_eval_f1': 45.685279187817265, 'mrpc_en_eval_accuracy': 47.549019607843135, 'mrpc_en_eval_gen_len': 2.7205882352941178, 'mrpc_en_eval_runtime': 1.856, 'mrpc_en_eval_samples_per_second': 109.915, 'epoch': 1.96}
{'mrpc_en_eval_loss': 0.39015302062034607, 'mrpc_en_eval_f1': 45.685279187817265, 'mrpc_en_eval_accuracy': 47.549019607843135, 'mrpc_en_eval_gen_len': 2.7205882352941178, 'mrpc_en_eval_runtime': 1.856, 'mrpc_en_eval_samples_per_second': 109.915, 'epoch': 1.96, 'eval_average_metrics': 46.617149397830204}
0%|▏ | 100/60000 [01:31<12:02:17, 1.38it/s]***** Running Evaluation *****
Num examples = 204
Batch size = 80
### n_samples 204███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:01<00:00, 2.55it/s]
{'mrpc_en_eval_loss': 0.30966323614120483, 'mrpc_en_eval_f1': 68.48249027237354, 'mrpc_en_eval_accuracy': 60.29411764705882, 'mrpc_en_eval_gen_len': 2.426470588235294, 'mrpc_en_eval_runtime': 1.8275, 'mrpc_en_eval_samples_per_second': 111.625, 'epoch': 2.17}
{'mrpc_en_eval_loss': 0.30966323614120483, 'mrpc_en_eval_f1': 68.48249027237354, 'mrpc_en_eval_accuracy': 60.29411764705882, 'mrpc_en_eval_gen_len': 2.426470588235294, 'mrpc_en_eval_runtime': 1.8275, 'mrpc_en_eval_samples_per_second': 111.625, 'epoch': 2.17, 'eval_average_metrics': 64.38830395971618}
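To make "differ a lot" concrete, here is a small sketch that tabulates the eval losses copied from the two logs above and computes the relative gap at each evaluation step. The dictionaries and helper names are only for illustration. In these logs the resumed loss is higher at every evaluation, and the relative gap rises from under 0.01% at the first evaluation to about 5% at step 100.

```python
# Eval losses copied from the two logs above, keyed by global step.
uninterrupted = {
    40: 4.203256607055664,
    50: 3.262455463409424,
    60: 1.9655567407608032,
    70: 0.7519775032997131,
    80: 0.4142318665981293,
    90: 0.3786551058292389,
    100: 0.29472649097442627,
}
resumed = {
    40: 4.203643321990967,
    50: 3.2706634998321533,
    60: 1.9863247871398926,
    70: 0.7721647620201111,
    80: 0.42692506313323975,
    90: 0.39015302062034607,
    100: 0.30966323614120483,
}


def relative_gap(base, other):
    """Relative difference of `other` with respect to `base`."""
    return (other - base) / base


def compare(base_run, other_run):
    # One row per evaluation step: (step, base loss, other loss, relative gap).
    return [
        (step, base_run[step], other_run[step], relative_gap(base_run[step], other_run[step]))
        for step in sorted(base_run)
    ]


for step, base_loss, resumed_loss, gap in compare(uninterrupted, resumed):
    print(f"step {step:>3}: {base_loss:.4f} vs {resumed_loss:.4f} ({gap:+.2%})")
```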
Expected behavior
Resuming from a checkpoint should give the same results as training without interruption.
Thank you for your help @sgugger
About this issue
- State: closed
- Created 3 years ago
- Comments: 24 (6 by maintainers)
This might be due to FP16. Could you check whether you get the same results without FP16? The reason is that we don't save the state of the gradient scaler in mixed-precision training, which is another thing that needs to be restored. I can make a PR to fix that tomorrow.
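To illustrate the missing piece, here is a minimal sketch of saving and restoring a `torch.cuda.amp.GradScaler` next to a checkpoint with its `state_dict()` / `load_state_dict()` methods. The scaler adjusts its loss scale during training, so a freshly constructed one restarts from its initial scale. One plausible effect is that it skips or scales steps differently from the uninterrupted run. The file name `scaler.pt` and the helper names are hypothetical and are not the Trainer's API.

```python
import os

import torch

SCALER_FILE = "scaler.pt"  # hypothetical file name for the scaler state


def save_scaler_state(scaler, checkpoint_dir):
    """Write the GradScaler state (loss scale, growth tracker, ...) next to the checkpoint."""
    os.makedirs(checkpoint_dir, exist_ok=True)
    torch.save(scaler.state_dict(), os.path.join(checkpoint_dir, SCALER_FILE))


def load_scaler_state(scaler, checkpoint_dir):
    """Restore the scaler state if it was saved; return True when something was loaded."""
    path = os.path.join(checkpoint_dir, SCALER_FILE)
    if not os.path.isfile(path):
        return False
    scaler.load_state_dict(torch.load(path))
    return True
```

The scaler is only active on CUDA. On a CPU-only machine it is disabled and its state dict is empty, so the round trip is trivially a no-op there.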
Hi, I cannot really express how much I appreciate this. Thank you both very much for working on this. It would be wonderful to have resuming fixed in the Trainer. Thanks for your efforts.
@sgugger is working on it in https://github.com/huggingface/transformers/pull/11582
Thank you for your detailed follow-up, @dorooddorood606, and for sharing the experiments you have tried.
I agree that it'd be awesome to be able to resume as if there had been no stop.
Please give us some time: we are going to discuss whether it is feasible to make this happen, as there are so many moving parts to consider, and if so, we will build it from the ground up.
We will keep you posted.
Dear @stas00
I really appreciate your input on reproducibility when resuming from checkpoints. I have tried to follow your points to state the problem more clearly.
Problem statement
If a user trains a model for some steps and then reloads it from a checkpoint, the results differ from those of training the model without breaks.
How to reproduce the issue
Transformers version: I am using the 4.6.0dev version of transformers
Please clone this repository, which contains a minimal example
To run the code, please run this command. During the run, kill the process 2-3 times, each time right after the model is saved (every 50 steps). Then compare the final results of the full run with resuming against those of training without any breaks. The results differ. Please let me know if you need any further information.
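The same comparison can be checked in a self-contained toy setting. The sketch below uses plain PyTorch, not the Trainer and not the linked repository. It trains a tiny model with dropout once without interruption, then again with a simulated stop and restart halfway. The restart restores the model, optimizer, scheduler, and torch RNG state. The model, sizes, seeds, and step counts are arbitrary assumptions. On CPU the two runs should match bit for bit, while skipping the RNG restore makes them diverge. It does not cover a GradScaler, data-loader ordering, or CUDA nondeterminism.

```python
import copy

import torch
from torch import nn


def build(seed):
    # Fresh model/optimizer/scheduler, as a newly started process would create them.
    torch.manual_seed(seed)
    model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Dropout(0.5), nn.Linear(8, 1))
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)
    return model, optimizer, scheduler


def train_steps(model, optimizer, scheduler, n_steps):
    model.train()
    for _ in range(n_steps):
        # Random batches and dropout masks both draw from the global torch RNG.
        x, y = torch.randn(16, 4), torch.randn(16, 1)
        loss = nn.functional.mse_loss(model(x), y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()


def run_uninterrupted(total_steps=20):
    model, optimizer, scheduler = build(seed=0)
    train_steps(model, optimizer, scheduler, total_steps)
    return model


def run_with_resume(stop_at=10, total_steps=20, restore_rng=True):
    model, optimizer, scheduler = build(seed=0)
    train_steps(model, optimizer, scheduler, stop_at)
    # In-memory "checkpoint"; deep copies because state dicts reference live tensors.
    checkpoint = {
        "model": copy.deepcopy(model.state_dict()),
        "optimizer": copy.deepcopy(optimizer.state_dict()),
        "scheduler": copy.deepcopy(scheduler.state_dict()),
        "rng": torch.get_rng_state(),
    }
    # Simulated restart: different seed, then load everything back.
    # The optimizer is loaded after the scheduler is built, since building it resets the lr.
    model, optimizer, scheduler = build(seed=123)
    model.load_state_dict(checkpoint["model"])
    optimizer.load_state_dict(checkpoint["optimizer"])
    scheduler.load_state_dict(checkpoint["scheduler"])
    if restore_rng:
        torch.set_rng_state(checkpoint["rng"])
    train_steps(model, optimizer, scheduler, total_steps - stop_at)
    return model


def same_params(a, b):
    return all(torch.equal(p, q) for p, q in zip(a.parameters(), b.parameters()))
```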
Modifications made to the Trainer class to make it reproducible:
I applied the following modifications to the Trainer class:
Following your suggestions, I save the random states and reload them before reloading the checkpoint in the Trainer class. Please see https://github.com/dorooddorood606/reproducibility/blob/f5902af4669bba8aaee326efdb0cd459e25be675/trainer.py#L126 and https://github.com/dorooddorood606/reproducibility/blob/f5902af4669bba8aaee326efdb0cd459e25be675/trainer.py#L200
Each time a checkpoint is saved, I also save a copy of it in output_dir. I believe that, in addition to keeping only the best checkpoint so far, we also need to keep the last checkpoint so that training can continue from its last state. Please see https://github.com/dorooddorood606/reproducibility/blob/f5902af4669bba8aaee326efdb0cd459e25be675/trainer.py#L87
In run_glue.py, I get the last checkpoint from the copy saved in the main output_dir; please see https://github.com/dorooddorood606/reproducibility/blob/f5902af4669bba8aaee326efdb0cd459e25be675/run_glue.py#L46
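The random-state handling in the first item can be sketched roughly as follows. This is not the code in the linked trainer.py, which is not reproduced here. The sketch collects the generator states that typically matter into one dict and restores them: Python `random`, NumPy, torch CPU, and CUDA when available. The function names are hypothetical, and writing the dict to disk is left out.

```python
import random

import numpy as np
import torch


def capture_rng_state():
    """Snapshot the Python, NumPy, torch CPU and (if present) CUDA generator states."""
    state = {
        "python": random.getstate(),
        "numpy": np.random.get_state(),
        "cpu": torch.get_rng_state(),
    }
    if torch.cuda.is_available():
        state["cuda"] = torch.cuda.get_rng_state_all()
    return state


def restore_rng_state(state):
    """Put every generator back to the snapshot taken by capture_rng_state()."""
    random.setstate(state["python"])
    np.random.set_state(state["numpy"])
    torch.set_rng_state(state["cpu"])
    if torch.cuda.is_available() and "cuda" in state:
        torch.cuda.set_rng_state_all(state["cuda"])
```

Restoring these states only helps if every consumer of randomness (data shuffling, dropout, augmentation) replays the same calls in the same order after the restart.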
Larger impact of this issue
To me, fixing this issue with resuming from checkpoints would benefit all users who need this option. I would greatly appreciate it if you could spare some of your precious time to help with this issue.
Dear @stas00, thank you for the reminder; I will follow the points you mentioned. I thought there was also a bug in the Trainer, since I observed the same problem with an unmodified BERT-base model, but that randomness issue was resolved by upgrading to transformers 4.6.0.