transformers: Mismatch of the mask token id of BART between fairseq and huggingface

šŸ› Bug

The mask token id of BART differs between fairseq (torch.hub) and huggingface, and this discrepancy leads to different mask-filling results. So I wonder which token id is actually correct.

(After checking the norm of the embedding at each mask token id, I suspect torch.hub is correct. I have posted the same issue on the fairseq GitHub and am waiting for a reply.)
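
(For illustration, here is a minimal sketch of that norm check; it assumes the fairseq bart.base hub model and simply compares the embedding row at fairseq's <mask> id, 51200, with the row at the id huggingface reports as mask_token_id, 50264.)

import torch

bart = torch.hub.load('pytorch/fairseq', 'bart.base')
emb = bart.model.encoder.embed_tokens.weight
# row 51200 is <mask> in the fairseq dictionary;
# row 50264 is what the huggingface tokenizer reports as mask_token_id
print(emb[51200].norm().item())
print(emb[50264].norm().item())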

To Reproduce

Code sample

from transformers import BartForConditionalGeneration, BartTokenizer
model = BartForConditionalGeneration.from_pretrained("facebook/bart-base", force_bos_token_to_be_generated=True)
tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
assert tokenizer.mask_token_id == 50264
example_english_phrase = "<mask> cat is <mask>."
batch = tokenizer(example_english_phrase, return_tensors='pt')
generated_ids = model.generate(batch['input_ids'], return_dict_in_generate=True, num_beams=10, num_return_sequences=1, output_scores=True)
print(" ".join(tokenizer.convert_ids_to_tokens(generated_ids[0][0])))

# </s> <s> This Ġcat Ġis Ġadorable . </s>

import torch
bart = torch.hub.load('pytorch/fairseq', 'bart.base')
bart.eval()
assert bart.task.source_dictionary.indices["<mask>"]  == 51200
assert bart.task.source_dictionary.indices["madeupword0003"] == tokenizer.mask_token_id
# Somehow the huggingface model has a smaller vocab size, so index 51200 is out of range
assert len(model.model.encoder.embed_tokens.weight) == 50265
assert len(bart.model.encoder.embed_tokens.weight) == 51201
# But the embedding at tokenizer.mask_token_id is the same between the two models
assert all(bart.model.encoder.embed_tokens.weight[tokenizer.mask_token_id] == model.model.encoder.embed_tokens.weight[tokenizer.mask_token_id])

def fill_mask(
        model,
        masked_inputs,
        topk = 1,
        match_source_len = False,
        masked_token = '<mask>',
        **generate_kwargs
):
    batch_tokens = []
    for masked_input in masked_inputs:
        assert masked_token in masked_input, \
            "please add one {} token for the input".format(masked_token)
        text_spans = masked_input.split(masked_token)
        text_spans_bpe = (' {0} '.format(masked_token)).join(
            [model.bpe.encode(text_span.rstrip()) for text_span in text_spans]
        ).strip()
        tokens = model.task.source_dictionary.encode_line(
            '<s> ' + text_spans_bpe + ' </s>',
            append_eos=False,
            add_if_not_exist=False,
        ).long()
        batch_tokens.append(tokens)
    generate_kwargs['beam'] = max(
        topk,
        generate_kwargs.get('beam', -1),
    )
    generate_kwargs['match_source_len'] = match_source_len
    batch_hypos = model.generate(batch_tokens, **generate_kwargs)
    return batch_hypos


masked_inputs = [example_english_phrase]
generate_kwargs = {}
generate_kwargs['beam'] = 10
generate_kwargs['match_source_len'] = False
batch_hypos = fill_mask(bart, masked_inputs, **generate_kwargs)
print(" ".join(tokenizer.convert_ids_to_tokens(batch_hypos[0][0]["tokens"])))

# <s> The Ġcat Ġis Ġdead . </s>

#### replace <mask> with madeupword0003 ####
example_english_phrase = "madeupword0003 cat is madeupword0003."
masked_inputs = [example_english_phrase]
batch_hypos = fill_mask(bart, masked_inputs, masked_token="madeupword0003", **generate_kwargs)
print(" ".join(tokenizer.convert_ids_to_tokens(batch_hypos[0][0]["tokens"])))

# <s> This Ġcat Ġis Ġadorable . </s>

Environment

  • PyTorch Version: 1.5.1+cu101
  • OS (e.g., Linux): Linux
  • Python version: 3.6.10
  • transformers version: 4.2.1
  • CUDA version: 10.1

Additional context

About this issue

  • State: closed
  • Created 3 years ago
  • Comments: 26 (13 by maintainers)

Most upvoted comments

I’m okay with updating the weights in a new commit since this fixes issues, as people can still revert to the previous commit if they really need to.
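
(A minimal sketch of that revert, assuming the revision argument of from_pretrained; the commit hash below is only a placeholder for the actual pre-update commit in the repo's history.)

from transformers import BartForConditionalGeneration

# "<pre-update-commit-sha>" is a placeholder; take the real hash from the
# commit history at https://huggingface.co/facebook/bart-base/commits/main
model = BartForConditionalGeneration.from_pretrained(
    "facebook/bart-base", revision="<pre-update-commit-sha>"
)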

I’ve updated the weights for all 3 checkpoints pt, tf, flax https://huggingface.co/facebook/bart-base/tree/main

This issue only affects bart-base; bart-large does not have this problem, so there is no need to change the weights there 😃

Here is a temporary solution: I replaced HF’s <mask> embedding with fairseq’s <mask> embedding. The model is available at https://huggingface.co/liangtaiwan/bart-base-correct-mask-embedding
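
(For reference, a rough sketch of how the same replacement could be done locally, assuming the fairseq bart.base hub model is available; the output directory name is only illustrative.)

import torch
from transformers import BartForConditionalGeneration

bart = torch.hub.load('pytorch/fairseq', 'bart.base')
hf_model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")

# copy fairseq's trained <mask> embedding (row 51200) into the row that
# huggingface treats as <mask> (row 50264); the embedding tables are tied,
# so writing to the shared table also updates encoder/decoder embed_tokens
with torch.no_grad():
    hf_model.model.shared.weight[50264] = bart.model.encoder.embed_tokens.weight[51200]

hf_model.save_pretrained("bart-base-fixed-mask-embedding")  # illustrative path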

You can verify that the new weights are correct with the following script.

import torch
from transformers import BartModel, BartTokenizer

# fairseq bart.base

bart = torch.hub.load('pytorch/fairseq', 'bart.base')
mask_token_id = bart.task.source_dictionary.indices["<mask>"]

mask_token_weight_fairseq = bart.model.encoder.embed_tokens.weight[mask_token_id].detach()

# my bart-base

hf_tok = BartTokenizer.from_pretrained("liangtaiwan/bart-base-correct-mask-embedding")
mask_token_id_hf = hf_tok.mask_token_id
hf_model = BartModel.from_pretrained("liangtaiwan/bart-base-correct-mask-embedding")

mask_token_weight_hf = hf_model.encoder.embed_tokens.weight[mask_token_id_hf]
assert torch.equal(mask_token_weight_hf, mask_token_weight_fairseq)

# HF bart-base
hf_original_model = BartModel.from_pretrained("facebook/bart-base")

hf_original_model_state_dict = hf_original_model.state_dict()
hf_model_state_dict = hf_model.state_dict()
embeddings = ["shared.weight", "encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]

# check weight 
for k in hf_model_state_dict.keys():
    if k in embeddings:
        continue
    assert torch.equal(hf_model_state_dict[k], hf_original_model_state_dict[k])

# check embedding
for k in embeddings:
    assert torch.equal(hf_model_state_dict[k][:-1], hf_original_model_state_dict[k][:-1])

However, I ran some prompt-based language model experiments, and the results are almost identical; HF’s embedding sometimes even performs better.