transformers: Bert Checkpoint Breaks 3.0.2 -> 3.1.0 due to new buffer in BertEmbeddings

Hi,

Thanks for the great library. I noticed that this line was added (https://github.com/huggingface/transformers/blob/v3.1.0/src/transformers/modeling_bert.py#L190) in the latest release.

It breaks loading of checkpoints that were saved before this line existed:

	Missing key(s) in state_dict: "generator_model.electra.embeddings.position_ids", "discriminator_model.electra.embeddings.position_ids". 
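
For reference, a minimal sketch of how the failure surfaces (the checkpoint path is just illustrative):

    import torch
    from transformers import BertConfig, BertModel

    # Under v3.1.0, BertEmbeddings registers a constant "position_ids" buffer
    # in its constructor, so the key becomes part of the expected state dict.
    model = BertModel(BertConfig())

    # A state dict serialized under v3.0.2 does not contain that key,
    # so a strict load raises "Missing key(s) in state_dict: ...position_ids".
    old_state_dict = torch.load("bert_checkpoint_v3.0.2.pt")  # illustrative path
    model.load_state_dict(old_state_dict)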

Most upvoted comments

You can also use the load_state_dict method with the strict option set to False:

model.load_state_dict(state_dict, strict=False)

I think it’s safe to use model.load_state_dict(state_dict, strict=False) if the only missing information is the position_ids buffer. This tensor is indeed used by the model, but it’s just a constant tensor containing a list of integers from 0 to the maximum number of position embeddings. The tensor is first created in the constructor of the BertEmbeddings class, in this line:

https://github.com/huggingface/transformers/blob/fcf83011dffce3f2e8aad906f07c1ec14668f877/src/transformers/models/bert/modeling_bert.py#L182

As such, it’s not really part of the optimizable parameters of the model. This means that it doesn’t matter if position_ids is not available when calling load_state_dict, because the line above will create it anyway in the constructor with the required values.

Hi @LysandreJik

Thanks for the proposed solution.

In my case, I am using PyTorch Lightning, which has its own saving and loading infrastructure, so the from_pretrained method can't really be used.

The strict flag is a good patch for now.

In general, I think complex projects built on top of the library cannot rely on from_pretrained, especially when other ecosystems are involved.
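
For what it's worth, newer PyTorch Lightning releases let you pass the flag straight through load_from_checkpoint; if your version supports the strict argument, something like this should work (the module name and checkpoint path are placeholders):

    # MyLightningModule and the checkpoint path are placeholders for your own project.
    model = MyLightningModule.load_from_checkpoint(
        "path/to/checkpoint.ckpt",
        strict=False,  # tolerate the missing embeddings.position_ids buffer
    )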

I encountered the same issue. Old checkpoints (3.0.2) cannot be loaded in 3.1.0 due to the missing-key error.

It also breaks for me. The attribute embeddings.position_ids can't be loaded if the model artifact was trained with v3.0.2, so it raises an error.

@patrickvonplaten It still seems to break for me:


Traceback (most recent call last):
  File "/opt/conda/envs/py36/bin/transformervae", line 33, in <module>
    sys.exit(load_entry_point('exs-transformervae', 'console_scripts', 'transformervae')())
  File "/opt/conda/envs/py36/lib/python3.6/site-packages/click/core.py", line 829, in __call__
    return self.main(*args, **kwargs)
  File "/opt/conda/envs/py36/lib/python3.6/site-packages/click/core.py", line 782, in main
    rv = self.invoke(ctx)
  File "/opt/conda/envs/py36/lib/python3.6/site-packages/click/core.py", line 1259, in invoke
    return _process_result(sub_ctx.command.invoke(sub_ctx))
  File "/opt/conda/envs/py36/lib/python3.6/site-packages/click/core.py", line 1066, in invoke
    return ctx.invoke(self.callback, **ctx.params)
  File "/opt/conda/envs/py36/lib/python3.6/site-packages/click/core.py", line 610, in invoke
    return callback(*args, **kwargs)
  File "/app/transformervae/cli.py", line 355, in train
    model = model_cls(hparams, pretrained_model=pretrained_model_path_or_config)
  File "/app/transformervae/models/regression.py", line 35, in __init__
    pretrained_model,
  File "/app/transformervae/models/finetuning_model.py", line 37, in __init__
    self.encoder, self.tokenizer = self.load_pretrained_encoder(pretrained_model)
  File "/app/transformervae/models/finetuning_model.py", line 89, in load_pretrained_encoder
    pl_model = AutoModel.load(pretrained_model)
  File "/app/transformervae/models/automodel.py", line 98, in load
    return model_cls.load(path)
  File "/app/transformervae/models/base.py", line 229, in load
    return cls.load_from_checkpoint(filepath)
  File "/opt/conda/envs/py36/lib/python3.6/site-packages/pytorch_lightning/core/saving.py", line 169, in load_from_checkpoint
    model = cls._load_model_state(checkpoint, *args, **kwargs)
  File "/opt/conda/envs/py36/lib/python3.6/site-packages/pytorch_lightning/core/saving.py", line 207, in _load_model_state
    model.load_state_dict(checkpoint['state_dict'])
  File "/opt/conda/envs/py36/lib/python3.6/site-packages/torch/nn/modules/module.py", line 1045, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for ElectraLanguageModel:
	Missing key(s) in state_dict: "generator_model.electra.embeddings.position_ids", "discriminator_model.electra.embeddings.position_ids".