transformers: gpt2 generation crashes when using `past` for some output lengths

πŸ› Bug

Model I am using (Bert, XLNet…): gpt2

Language I am using the model on (English, Chinese…): English

The problem arises when using:

  β€’ my own modified scripts: examples/run_generation.py with a three-line change to use `past` (see below)

The task I am working on is:

  β€’ my own task or dataset: open-ended text generation from a prompt

I am using run_generation.py to generate text from GPT-2, with the code slightly changed to make use of `past` to cache hidden states. It crashes for some output lengths.

To Reproduce

Steps to reproduce the behavior:

Apply this three-line change to the run_generation.py script:

$ git checkout f88c104d8 .   # current head of master
$ git diff
diff --git a/examples/run_generation.py b/examples/run_generation.py
index 2d91766..bfbf68a 100644
--- a/examples/run_generation.py
+++ b/examples/run_generation.py
@@ -113,9 +113,10 @@ def sample_sequence(model, length, context, num_samples=1, temperature=1, top_k=
     context = context.unsqueeze(0).repeat(num_samples, 1)
     generated = context
     with torch.no_grad():
+        past = None
         for _ in trange(length):

-            inputs = {'input_ids': generated}
+            inputs = {'input_ids': generated, 'past': past}
             if is_xlnet:
                 # XLNet is a direct (predict same token, not next token) and bi-directional model by default
                 # => need one additional dummy token in the input (will be masked), attention mask and target mapping (see model docstring)
@@ -136,6 +137,7 @@ def sample_sequence(model, length, context, num_samples=1, temperature=1, top_k=
                 inputs["langs"] = torch.tensor([xlm_lang] * inputs["input_ids"].shape[1], device=device).view(1, -1)

             outputs = model(**inputs)  # Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet/CTRL (cached hidden-states)
+            past = outputs[1]
             next_token_logits = outputs[0][:, -1, :] / (temperature if temperature > 0 else 1.)

             # repetition penalty from CTRL (https://arxiv.org/abs/1909.05858)
$ python examples/run_generation.py --prompt "Who was Jim Henson ? Jim Henson was a" --model_type gpt2 --model_name_or_path gpt2 --length 40
...
11/06/2019 11:38:36 - INFO - __main__ -   Namespace(device=device(type='cpu'), length=40, model_name_or_path='gpt2', model_type='gpt2', n_gpu=0, no_cuda=False, num_samples=1, padding_text='', prompt='Who was Jim Henson ? Jim Henson was a', repetition_penalty=1.0, seed=42, stop_token=None, temperature=1.0, top_k=0, top_p=0.9, xlm_lang='')
 88%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–                | 35/40 [00:04<00:00,  8.68it/s]
Traceback (most recent call last):
  File "examples/run_generation.py", line 262, in <module>
    main()
  File "examples/run_generation.py", line 247, in main
    device=args.device,
  File "examples/run_generation.py", line 139, in sample_sequence
    outputs = model(**inputs)  # Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet/CTRL (cached hidden-states)
  File "/Users/joelb/anaconda3/envs/huggingface/lib/python3.7/site-packages/torch/nn/modules/module.py", line 541, in __call__
    result = self.forward(*input, **kwargs)
  File "/Users/joelb/views/transformers/transformers/modeling_gpt2.py", line 546, in forward
    inputs_embeds=inputs_embeds)
  File "/Users/joelb/anaconda3/envs/huggingface/lib/python3.7/site-packages/torch/nn/modules/module.py", line 541, in __call__
    result = self.forward(*input, **kwargs)
  File "/Users/joelb/views/transformers/transformers/modeling_gpt2.py", line 438, in forward
    position_embeds = self.wpe(position_ids)
  File "/Users/joelb/anaconda3/envs/huggingface/lib/python3.7/site-packages/torch/nn/modules/module.py", line 541, in __call__
    result = self.forward(*input, **kwargs)
  File "/Users/joelb/anaconda3/envs/huggingface/lib/python3.7/site-packages/torch/nn/modules/sparse.py", line 114, in forward
    self.norm_type, self.scale_grad_by_freq, self.sparse)
  File "/Users/joelb/anaconda3/envs/huggingface/lib/python3.7/site-packages/torch/nn/functional.py", line 1484, in embedding
    return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
RuntimeError: index out of range: Tried to access index 1024 out of table with 1023 rows. at /Users/distiller/project/conda/conda-bld/pytorch_1570710797334/work/aten/src/TH/generic/THTensorEvenMoreMath.cpp:418

Expected behavior

Generation should complete without crashing. Note that the crash does not occur with the default --length 20.

I’m trying to use `past` for faster generation. I also tested the change with a smaller length that does not crash, and I observe that generation with caching is slightly slower than without. The sketch after this paragraph shows the loop I was aiming for.
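
For reference, here is the shape of the loop I was trying to end up with. My understanding (which may be wrong, and is the assumption behind this sketch) is that once `past` is available, only the newly generated token should be fed as `input_ids`, since the cached keys/values already cover the earlier tokens. This is a standalone sketch, not part of run_generation.py, and it uses greedy decoding instead of the script's top-p sampling to keep it short:

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")
model.eval()

# Encode the prompt; the first step feeds the whole prompt, later steps feed one token.
context = tokenizer.encode("Who was Jim Henson ? Jim Henson was a")
generated = torch.tensor([context], dtype=torch.long)

past = None
input_ids = generated
with torch.no_grad():
    for _ in range(40):
        outputs = model(input_ids=input_ids, past=past)  # (logits, past, ...) in this version
        next_token_logits = outputs[0][:, -1, :]
        past = outputs[1]                                # cache now covers every token seen so far
        next_token = torch.argmax(next_token_logits, dim=-1).unsqueeze(-1)
        generated = torch.cat((generated, next_token), dim=1)
        input_ids = next_token                           # feed only the new token from here on

print(tokenizer.decode(generated[0].tolist()))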

Environment

Platform: Darwin-18.7.0-x86_64-i386-64bit
Python: 3.7.4 (default, Aug 13 2019, 15:17:50) [Clang 4.0.1 (tags/RELEASE_401/final)]
PyTorch: 1.3.0
TensorFlow: 2.0.0

Additional context
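
My guess at the cause (not verified in the modeling code, so treat the details as an assumption): my patch keeps feeding the full `generated` sequence as `input_ids` while `past` keeps accumulating, so the cache grows by the whole sequence on every call and the position ids eventually run past GPT-2's 1024-entry position table. Rough arithmetic, assuming the prompt is 11 GPT-2 tokens:

prompt_len = 11  # my token count for the prompt; treat as approximate
past_len = 0
for step in range(40):
    input_len = prompt_len + step            # the full sequence is fed every step
    max_position = past_len + input_len - 1  # highest position id used this step
    if max_position >= 1024:
        print(step, max_position)            # -> 35 1025, matching the crash at 35/40
        break
    past_len += input_len                    # the cache grows by the whole input each call

If that is what is happening, it would also explain why I see no speedup at smaller lengths: attention is computed over the full sequence plus an ever-growing cache on every step.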

Most upvoted comments

That’s a great new part of the documentation. Thank you!