fairseq: BART model does NOT work properly when trained from scratch

🐛 Bug

I trained a BART model from scratch (without the “--restore-file $PATH” argument) for the summarization task. During inference, the decoded outputs are wrong. Here are some output samples from the model:

s in Wales have been warned to be "inadequate" by the Welsh Government over the next five years.
s of a man who died after being hit by a car have been named by police.
ing the murder of a man who was found dead at a house in County Antrim has been jailed for life.
 Ched Evans has told a court that she would have to be a woman accused of raping a woman in the UK.
 Glamorgan has signed a new two-year contract with the Premier League club.

The beginning of each output sentence is cut off. Note that when the pretrained model is fine-tuned on the same dataset (with “--restore-file $PATH”), everything is fine.

To Reproduce

Steps to reproduce the behavior (always include the command you ran):

  1. Train a BART model on the summarization dataset (XSum/CNNDM) from scratch.
TOTAL_NUM_UPDATES=15000
WARMUP_UPDATES=500      
LR=3e-05
MAX_TOKENS=2048
UPDATE_FREQ=2
SAVE_DIR=checkpoints/

CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 fairseq-train $DATA_PATH \
    --save-dir $SAVE_DIR \
    --max-tokens $MAX_TOKENS \
    --task translation \
    --source-lang source --target-lang target \
    --truncate-source \
    --layernorm-embedding \
    --share-all-embeddings \
    --share-decoder-input-output-embed \
    --reset-optimizer --reset-dataloader --reset-meters \
    --required-batch-size-multiple 1 \
    --arch bart_large \
    --criterion label_smoothed_cross_entropy \
    --label-smoothing 0.1 \
    --dropout 0.1 --attention-dropout 0.1 \
    --weight-decay 0.01 --optimizer adam --adam-betas "(0.9, 0.999)" --adam-eps 1e-08 \
    --clip-norm 0.1 \
    --lr-scheduler polynomial_decay --lr $LR --total-num-update $TOTAL_NUM_UPDATES --warmup-updates $WARMUP_UPDATES \
    --fp16 --update-freq $UPDATE_FREQ \
    --skip-invalid-size-inputs-valid-test \
    --find-unused-parameters;
  2. Run inference with the official code.
import torch
from fairseq.models.bart import BARTModel

# args.checkpoint_path and args.data_path come from the script's own argument parser
bart = BARTModel.from_pretrained(
    args.checkpoint_path,
    checkpoint_file="checkpoint_best.pt",
    data_name_or_path=args.data_path
)

bart.cuda()
bart.eval()
bart.half()
count = 1
bsz = 32
with open('test.source') as source, open('test.hypo', 'w') as fout:
    sline = source.readline().strip()
    slines = [sline]
    for sline in source:
        # Summarize a full batch of bsz source lines
        if count % bsz == 0:
            with torch.no_grad():
                hypotheses_batch = bart.sample(slines, beam=4, lenpen=2.0, max_len_b=140, min_len=55, no_repeat_ngram_size=3)

            for hypothesis in hypotheses_batch:
                fout.write(hypothesis + '\n')
                fout.flush()
            slines = []

        slines.append(sline.strip())
        count += 1
    # Summarize the final, possibly smaller, batch
    if slines != []:
        with torch.no_grad():
            hypotheses_batch = bart.sample(slines, beam=4, lenpen=2.0, max_len_b=140, min_len=55, no_repeat_ngram_size=3)
        for hypothesis in hypotheses_batch:
            fout.write(hypothesis + '\n')
            fout.flush()
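The batching loop above reads the first line separately and tracks a counter, which makes it easy to get wrong. A minimal sketch of the same idea with a generator is shown below; `batched` and `summarize_file` are hypothetical helpers (not fairseq APIs), and `summarize` stands for any function that maps a list of source lines to a list of hypotheses, such as a wrapper around `bart.sample(...)` called under `torch.no_grad()`.

```python
from typing import Callable, Iterable, Iterator, List


def batched(lines: Iterable[str], bsz: int) -> Iterator[List[str]]:
    """Yield lists of at most bsz stripped lines; the last batch may be shorter."""
    batch: List[str] = []
    for line in lines:
        batch.append(line.strip())
        if len(batch) == bsz:
            yield batch
            batch = []
    if batch:  # flush the remainder, if any
        yield batch


def summarize_file(
    lines: Iterable[str],
    summarize: Callable[[List[str]], List[str]],
    bsz: int = 32,
) -> List[str]:
    """Run a hypothetical batch summarizer over all lines, preserving order."""
    hypotheses: List[str] = []
    for batch in batched(lines, bsz):
        hypotheses.extend(summarize(batch))
    return hypotheses


if __name__ == "__main__":
    # Stub summarizer for illustration; a real one would call bart.sample(batch, ...)
    print(summarize_file(["a\n", "b\n", "c\n"], lambda b: [s.upper() for s in b], bsz=2))
```

With a real model, `summarize` would be something like `lambda b: bart.sample(b, beam=4, lenpen=2.0, max_len_b=140, min_len=55, no_repeat_ngram_size=3)`, and the returned hypotheses would be written to `test.hypo` one per line. An empty input file then produces no batches at all instead of a single empty source string.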

Expected behavior

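The decoded summaries should be complete sentences, as they are when the pretrained model is fine-tuned. Instead, the decoding goes wrong and the beginning of each decoded sentence is missing.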

Environment

  • fairseq Version: 0.9.0
  • PyTorch Version: 1.5.0
  • OS: Linux
  • How you installed fairseq: pip, source
  • Build command you used (if compiling from source): pip install --editable ./
  • Python version: 3.7.4
  • CUDA/cuDNN version: 10.2
  • GPU models and configuration: V100
  • Any other relevant information: None

Additional context

About this issue

  • State: closed
  • Created 4 years ago
  • Reactions: 3
  • Comments: 16 (1 by maintainers)


Most upvoted comments

Ah, you’re using the translation task rather than denoising. There’s probably a mismatch with the hub interface related to the beginning-of-sentence token, and you need to remove the <s> token here: https://github.com/pytorch/fairseq/blob/f732b403ec15244c41a24b9e28d6c5a411a511df/fairseq/models/bart/hub_interface.py#L58
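To make the suspected mismatch concrete, here is a toy sketch in plain Python (no fairseq) of the teacher-forced decoder inputs and labels under two target formats. It assumes that denoising-style targets begin with <s> while translation-task targets are just the tokens followed by </s>, and that the decoder input is the target with </s> moved to the front (as in fairseq's `move_eos_to_beginning` collation). All helper names here are hypothetical.

```python
from typing import List, Tuple

BOS, EOS = "<s>", "</s>"


def denoising_target(words: List[str]) -> List[str]:
    # Assumed denoising-style target: <s> w1 ... wn </s>
    return [BOS] + list(words) + [EOS]


def translation_target(words: List[str]) -> List[str]:
    # Assumed translation-task target: w1 ... wn </s> (no leading <s>)
    return list(words) + [EOS]


def prev_output_tokens(target: List[str]) -> List[str]:
    # Teacher-forced decoder input: </s> moved to the front, last token dropped
    return [EOS] + target[:-1]


def training_pairs(target: List[str]) -> List[Tuple[str, str]]:
    # (decoder input token, label) pairs the decoder is trained on, step by step
    return list(zip(prev_output_tokens(target), target))


def sees_bos_as_input(target: List[str]) -> bool:
    # Does the decoder ever receive <s> as an input token during training?
    return BOS in prev_output_tokens(target)


if __name__ == "__main__":
    words = ["the", "cat", "sat"]
    for name, tgt in [("denoising", denoising_target(words)),
                      ("translation", translation_target(words))]:
        print(name, training_pairs(tgt), sees_bos_as_input(tgt))
```

Under these assumptions, a model trained from scratch with the translation task never sees <s> as a decoder input, whereas a model initialized from denoising pretraining has learned what follows <s>. That is one plausible reason why forcing <s> at the start of decoding hurts only the from-scratch model.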

Hi @myleott, thanks for the reply. Yes, I am using the translation task. However, removing the "<s> " token doesn’t work for me; the outputs are the same.

I think the problem is caused by prefix_tokens. Forcing prefix_tokens to be None solves the issue.

Hi. I’m running into the same problem with sample. I’m following the fine-tuning recipe for BART using the latest version of fairseq with my own data (formatted as one sentence per line for the source and target files, without any special tokens added; for example, one line of each of these files could simply be “The cat sat on the mat.”).

@mcao610 Can you please elaborate on what you mean by setting prefix_tokens to None? prefix_tokens appears to be a string in hub_interface.py. Should you change the call to sample, or hub_interface.py directly? Thank you.

I mean setting the prefix_tokens variable in sequence_generator.py to None. For instance, you can add prefix_tokens = None before this for loop in sequence_generator.py:

https://github.com/pytorch/fairseq/blob/f732b403ec15244c41a24b9e28d6c5a411a511df/fairseq/sequence_generator.py#L293
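Why forcing a prefix can drop the start of the output can be illustrated with a toy greedy decoder. This is a minimal sketch under loud assumptions: the "model" is a hypothetical stand-in that simply emits, at each step, the token it saw at that position in a translation-style training target (it ignores the forced prefix, as a from-scratch model that never saw <s> as input might), and `toy_decode` roughly mimics stripping a leading <s> and the final </s> before detokenizing. None of these names are fairseq APIs.

```python
from typing import List, Optional

BOS, EOS = "<s>", "</s>"


def toy_step(memorized: List[str], step: int) -> str:
    # Hypothetical position-based "decoder": emit the training-target token at this step
    return memorized[step] if step < len(memorized) else EOS


def toy_generate(memorized: List[str],
                 prefix_tokens: Optional[List[str]] = None,
                 max_len: int = 10) -> List[str]:
    out: List[str] = []
    for step in range(max_len):
        if prefix_tokens is not None and step < len(prefix_tokens):
            tok = prefix_tokens[step]  # forced token, analogous to prefix_tokens in generation
        else:
            tok = toy_step(memorized, step)
        out.append(tok)
        if tok == EOS:
            break
    return out


def toy_decode(tokens: List[str]) -> str:
    # Drop a leading <s>, then join everything except </s>
    if tokens and tokens[0] == BOS:
        tokens = tokens[1:]
    return " ".join(t for t in tokens if t != EOS)


if __name__ == "__main__":
    target = ["the", "cat", "sat", "on", "the", "mat", EOS]
    print(toy_decode(toy_generate(target, prefix_tokens=[BOS])))  # cat sat on the mat
    print(toy_decode(toy_generate(target, prefix_tokens=None)))   # the cat sat on the mat
```

In the toy, forcing <s> consumes the first decoding step, so the first word never appears after the leading <s> is stripped; with `prefix_tokens=None` the full sequence comes back. This mirrors the truncated outputs in the report, but it is an illustration, not a description of fairseq's actual decoder.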