transformers: Mismatch of the mask token id of BART between fairseq and huggingface
🐛 Bug
The mask token id of BART differs between fairseq (torch.hub) and huggingface, and this discrepancy leads to different mask-filling results. Which token id is actually correct?
(After checking the norm of the embedding at each mask token id, I suspect that torch.hub is correct. I have posted the same issue on the fairseq GitHub and am waiting for a reply.)
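The norm check mentioned above can be sketched as follows. This is a minimal illustration, not the code used in the report: `row_l2_norm` is a hypothetical helper, and the toy rows stand in for real embedding vectors.

```python
import math

def row_l2_norm(row):
    """L2 norm of one embedding row (any iterable of numbers or 0-d tensors)."""
    return math.sqrt(sum(float(x) ** 2 for x in row))

# Toy stand-in for an embedding matrix; these are NOT real BART weights.
toy_embeddings = [
    [3.0, 4.0],  # a row with non-trivial values
    [0.0, 0.0],  # a row that was never updated
]
norms = [row_l2_norm(r) for r in toy_embeddings]
print(norms)  # [5.0, 0.0]
```

With the real models, the rows to compare would be the embedding at index 50264 in the huggingface model and the embeddings at indices 50264 and 51200 in the fairseq model. For torch tensors, `weight[idx].norm()` computes the same quantity.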
To Reproduce
Code sample
from transformers import BartForConditionalGeneration, BartTokenizer
model = BartForConditionalGeneration.from_pretrained("facebook/bart-base", force_bos_token_to_be_generated=True)
tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
assert tokenizer.mask_token_id == 50264
example_english_phrase = "<mask> cat is <mask>."
batch = tokenizer(example_english_phrase, return_tensors='pt')
generated_ids = model.generate(batch['input_ids'], return_dict_in_generate=True, num_beams=10, num_return_sequences=1, output_scores=True)
print(" ".join(tokenizer.convert_ids_to_tokens(generated_ids[0][0])))
# </s> <s> This Ġcat Ġis Ġadorable . </s>
import torch
bart = torch.hub.load('pytorch/fairseq', 'bart.base')
bart.eval()
assert bart.task.source_dictionary.indices["<mask>"] == 51200
assert bart.task.source_dictionary.indices["madeupword0003"] == tokenizer.mask_token_id
# Somehow the huggingface model has a smaller vocab size, and 51200 is out of index
assert len(model.model.encoder.embed_tokens.weight) == 50265
assert len(bart.model.encoder.embed_tokens.weight) == 51201
# But the embedding at tokenizer.mask_token_id is the same between the two models
assert all(bart.model.encoder.embed_tokens.weight[tokenizer.mask_token_id] == model.model.encoder.embed_tokens.weight[tokenizer.mask_token_id])
def fill_mask(
    model,
    masked_inputs,
    topk=1,
    match_source_len=False,
    masked_token='<mask>',
    **generate_kwargs
):
    batch_tokens = []
    for masked_input in masked_inputs:
        assert masked_token in masked_input, \
            "please add one {} token for the input".format(masked_token)
        # BPE-encode the text around each mask, then re-insert the mask token
        text_spans = masked_input.split(masked_token)
        text_spans_bpe = (' {0} '.format(masked_token)).join(
            [model.bpe.encode(text_span.rstrip()) for text_span in text_spans]
        ).strip()
        # Map the BPE string to fairseq dictionary indices
        tokens = model.task.source_dictionary.encode_line(
            '<s> ' + text_spans_bpe + ' </s>',
            append_eos=False,
            add_if_not_exist=False,
        ).long()
        batch_tokens.append(tokens)
    generate_kwargs['beam'] = max(
        topk,
        generate_kwargs.get('beam', -1),
    )
    generate_kwargs['match_source_len'] = match_source_len
    batch_hypos = model.generate(batch_tokens, **generate_kwargs)
    return batch_hypos
masked_inputs=[example_english_phrase]
generate_kwargs = {}
generate_kwargs['beam'] = 10
generate_kwargs['match_source_len'] = False
batch_hypos = fill_mask(bart,masked_inputs, **generate_kwargs)
print(" ".join(tokenizer.convert_ids_to_tokens(batch_hypos[0][0]["tokens"])))
# <s> The Ġcat Ġis Ġdead . </s>
#### replace <mask> with madeupword0003 ####
example_english_phrase = "madeupword0003 cat is madeupword0003."
masked_inputs=[example_english_phrase]
batch_hypos = fill_mask(bart,masked_inputs, masked_token = "madeupword0003", **generate_kwargs)
print(" ".join(tokenizer.convert_ids_to_tokens(batch_hypos[0][0]["tokens"])))
# <s> This Ġcat Ġis Ġadorable . </s>
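The id discrepancy shown above can be made concrete with a small remapping sketch. It assumes that every id other than the mask id coincides between the two vocabularies, which the decoding step above relies on implicitly but which is not verified here. `hf_ids_to_fairseq` is a hypothetical helper, not part of either library.

```python
HF_MASK_ID = 50264       # tokenizer.mask_token_id in the report
FAIRSEQ_MASK_ID = 51200  # bart.task.source_dictionary.indices["<mask>"] in the report

def hf_ids_to_fairseq(ids):
    """Return a copy of ids with the HF mask id swapped for the fairseq one."""
    return [FAIRSEQ_MASK_ID if i == HF_MASK_ID else i for i in ids]

# Hypothetical token ids: 0 and 2 are <s> and </s>; 100 and 200 are placeholders.
example = [0, HF_MASK_ID, 100, 200, HF_MASK_ID, 2]
print(hf_ids_to_fairseq(example))  # [0, 51200, 100, 200, 51200, 2]
```

Without such a translation, an input carrying id 50264 is read by the fairseq model as madeupword0003 rather than <mask>, which matches the behaviour reproduced above.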
Environment
- PyTorch Version: 1.5.1+cu101
- OS (e.g., Linux): Linux
- Python version: 3.6.10
- transformers version: 4.2.1
- CUDA version: 10.1
Additional context
About this issue
- State: closed
- Created 3 years ago
- Comments: 26 (13 by maintainers)
I'm okay with updating the weights in a new commit since this fixes issues, as people can still revert to the previous commit if they really need to.
I've updated the weights for all 3 checkpoints (pt, tf, flax): https://huggingface.co/facebook/bart-base/tree/main
This issue is only associated with bart-base; bart-large does not have this problem, so there is no need to change the weights there.
Here is a temporary solution: I replaced HF's <mask> embedding with fairseq's <mask> embedding. The model is at https://huggingface.co/liangtaiwan/bart-base-correct-mask-embedding
You can verify that the new weight is correct with the following script.
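The script itself did not survive in this copy of the thread. As a rough idea of what such a check could look like (this is a sketch, not the commenter's script), one could compare the two <mask> embedding rows element-wise. `max_abs_diff` and `rows_match` are hypothetical helpers, and the rows below are toy values rather than real weights.

```python
def max_abs_diff(row_a, row_b):
    """Largest element-wise absolute difference between two equal-length rows."""
    a = [float(x) for x in row_a]
    b = [float(x) for x in row_b]
    if len(a) != len(b):
        raise ValueError("rows must have the same length")
    return max(abs(x - y) for x, y in zip(a, b))

def rows_match(row_a, row_b, atol=1e-6):
    """True if the two rows agree within an absolute tolerance."""
    return max_abs_diff(row_a, row_b) <= atol

# Toy rows standing in for embedding vectors (not real weights).
fairseq_mask_row = [0.5, -1.25, 2.0]
patched_mask_row = [0.5, -1.25, 2.0]
old_mask_row = [0.0, 0.0, 0.0]

print(rows_match(fairseq_mask_row, patched_mask_row))  # True
print(rows_match(fairseq_mask_row, old_mask_row))      # False
```

With the real checkpoints, the rows to compare would be `bart.model.encoder.embed_tokens.weight[51200]` from fairseq and `model.model.encoder.embed_tokens.weight[50264]` from the patched huggingface model.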
However, I ran some prompt-based language model experiments, and the results are almost identical. The HF version is sometimes even better.