transformers: Bert Checkpoint Breaks 3.0.2 -> 3.1.0 due to new buffer in BertEmbeddings
Hi,
Thanks for the great library. I noticed this line being added (https://github.com/huggingface/transformers/blob/v3.1.0/src/transformers/modeling_bert.py#L190) in the latest update.
It breaks checkpoints that were saved when this line wasn't there:

```
Missing key(s) in state_dict: "generator_model.electra.embeddings.position_ids", "discriminator_model.electra.embeddings.position_ids".
```
About this issue
- Original URL
- State: closed
- Created 4 years ago
- Reactions: 5
- Comments: 28 (10 by maintainers)
You can also use the `load_state_dict` method with the `strict` option set to `False`:

I think it's safe to use `model.load_state_dict(state_dict, strict=False)` if the only missing information is the `position_ids` buffer. This tensor is indeed used by the model, but it's just a constant tensor containing a list of integers from 0 to the maximum number of position embeddings. The tensor is first created in the constructor of the `BertEmbeddings` class, in this line: https://github.com/huggingface/transformers/blob/fcf83011dffce3f2e8aad906f07c1ec14668f877/src/transformers/models/bert/modeling_bert.py#L182
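For context, a registered buffer behaves like the sketch below (a toy module for illustration, not the actual transformers code): it appears in `state_dict()` but is not a trainable parameter, and it is rebuilt in the constructor every time the module is created.

```python
import torch
import torch.nn as nn

class EmbeddingsSketch(nn.Module):
    """Toy stand-in illustrating a position_ids-style buffer."""

    def __init__(self, max_position_embeddings=512):
        super().__init__()
        # Registered buffers are saved in state_dict() but are not
        # trainable parameters; this one is just the integers 0..N-1.
        self.register_buffer(
            "position_ids",
            torch.arange(max_position_embeddings).expand((1, -1)),
        )

m = EmbeddingsSketch()
print("position_ids" in m.state_dict())  # True: buffers are serialized
print(len(list(m.parameters())))         # 0: the buffer is not a parameter
```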
As such, it's not really part of the optimizable parameters of the model. This means that it doesn't matter if `position_ids` is not available when calling `load_state_dict`, because the line above will create it anyway in the constructor with the required values.

Hi @LysandreJik,
Thanks for the proposed solution.
In my case, I am using PyTorch Lightning, which has its own saving and loading infrastructure, so the `from_pretrained` method can't exactly be used. The `strict` flag is a good patch for now.

I think, in general, when building on top of the library for complex projects, one cannot rely on `from_pretrained`, especially when using other ecosystems.

I encountered the same issue: old checkpoints (3.0.2) cannot be loaded in 3.1.0 due to a KeyError.
It also breaks for me. The attribute `embeddings.position_ids` can't be loaded if the model artifact was trained with v3.0.2, so it raises a KeyError.
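For anyone hitting this outside of `from_pretrained` (e.g. under PyTorch Lightning), the `strict=False` workaround discussed above can be sketched end-to-end on a toy module (hypothetical class and variable names, not the real library code):

```python
import torch
import torch.nn as nn

class ToyEmbeddings(nn.Module):
    """Toy module mimicking the 3.1.0-style embeddings layer."""

    def __init__(self, max_positions=512, dim=8):
        super().__init__()
        self.word_embeddings = nn.Embedding(100, dim)
        # Buffer added in newer code; checkpoints saved before it
        # existed have no "position_ids" entry in their state_dict.
        self.register_buffer(
            "position_ids",
            torch.arange(max_positions).expand((1, -1)),
        )

model = ToyEmbeddings()
# Simulate a checkpoint saved before the buffer existed.
old_state = {k: v for k, v in model.state_dict().items() if k != "position_ids"}

# strict=True would raise: Missing key(s) in state_dict: "position_ids".
# strict=False skips the missing key; the constructor already built the
# buffer with the required values, so nothing is lost.
result = model.load_state_dict(old_state, strict=False)
print(result.missing_keys)  # ['position_ids']
```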
@patrickvonplaten seems to break it for me: