transformers: gpt2 generation crashes when using `past` for some output lengths
🐛 Bug
Model I am using (Bert, XLNet…): gpt2
Language I am using the model on (English, Chinese…): English
The problem arises when using:
- the official example scripts: (give details)
- my own modified scripts: (give details)
The task I am working on is:
- an official GLUE/SQUaD task: (give the name)
- my own task or dataset: (give details)
I am using `run_generation.py` to generate text from gpt2, with the code slightly changed to use `past` to cache hidden states. It crashes in some cases.
To Reproduce
Steps to reproduce the behavior:
Apply the three-line change to the `run_generation.py` script:
$ git checkout f88c104d8 . # current head of master
$ git diff
diff --git a/examples/run_generation.py b/examples/run_generation.py
index 2d91766..bfbf68a 100644
--- a/examples/run_generation.py
+++ b/examples/run_generation.py
@@ -113,9 +113,10 @@ def sample_sequence(model, length, context, num_samples=1, temperature=1, top_k=
context = context.unsqueeze(0).repeat(num_samples, 1)
generated = context
with torch.no_grad():
+ past = None
for _ in trange(length):
- inputs = {'input_ids': generated}
+ inputs = {'input_ids': generated, 'past': past}
if is_xlnet:
# XLNet is a direct (predict same token, not next token) and bi-directional model by default
# => need one additional dummy token in the input (will be masked), attention mask and target mapping (see model docstring)
@@ -136,6 +137,7 @@ def sample_sequence(model, length, context, num_samples=1, temperature=1, top_k=
inputs["langs"] = torch.tensor([xlm_lang] * inputs["input_ids"].shape[1], device=device).view(1, -1)
outputs = model(**inputs) # Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet/CTRL (cached hidden-states)
+ past = outputs[1]
next_token_logits = outputs[0][:, -1, :] / (temperature if temperature > 0 else 1.)
# repetition penalty from CTRL (https://arxiv.org/abs/1909.05858)
$ python examples/run_generation.py --prompt "Who was Jim Henson ? Jim Henson was a" --model_type gpt2 --model_name_or_path gpt2 --length 40
...
11/06/2019 11:38:36 - INFO - __main__ - Namespace(device=device(type='cpu'), length=40, model_name_or_path='gpt2', model_type='gpt2', n_gpu=0, no_cuda=False, num_samples=1, padding_text='', prompt='Who was Jim Henson ? Jim Henson was a', repetition_penalty=1.0, seed=42, stop_token=None, temperature=1.0, top_k=0, top_p=0.9, xlm_lang='')
 88%|█████████████████████████████████████████████████████████████████████████████████████████████████████████▏            | 35/40 [00:04<00:00,  8.68it/s]
Traceback (most recent call last):
File "examples/run_generation.py", line 262, in <module>
main()
File "examples/run_generation.py", line 247, in main
device=args.device,
File "examples/run_generation.py", line 139, in sample_sequence
outputs = model(**inputs) # Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet/CTRL (cached hidden-states)
File "/Users/joelb/anaconda3/envs/huggingface/lib/python3.7/site-packages/torch/nn/modules/module.py", line 541, in __call__
result = self.forward(*input, **kwargs)
File "/Users/joelb/views/transformers/transformers/modeling_gpt2.py", line 546, in forward
inputs_embeds=inputs_embeds)
File "/Users/joelb/anaconda3/envs/huggingface/lib/python3.7/site-packages/torch/nn/modules/module.py", line 541, in __call__
result = self.forward(*input, **kwargs)
File "/Users/joelb/views/transformers/transformers/modeling_gpt2.py", line 438, in forward
position_embeds = self.wpe(position_ids)
File "/Users/joelb/anaconda3/envs/huggingface/lib/python3.7/site-packages/torch/nn/modules/module.py", line 541, in __call__
result = self.forward(*input, **kwargs)
File "/Users/joelb/anaconda3/envs/huggingface/lib/python3.7/site-packages/torch/nn/modules/sparse.py", line 114, in forward
self.norm_type, self.scale_grad_by_freq, self.sparse)
File "/Users/joelb/anaconda3/envs/huggingface/lib/python3.7/site-packages/torch/nn/functional.py", line 1484, in embedding
return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
RuntimeError: index out of range: Tried to access index 1024 out of table with 1023 rows. at /Users/distiller/project/conda/conda-bld/pytorch_1570710797334/work/aten/src/TH/generic/THTensorEvenMoreMath.cpp:418
Expected behavior
The crash does not occur with the default `--length 20`.
I'm trying to make this work for faster generation speed. So I tested the change with a smaller length that does not crash, but I observe that generation with caching is slightly slower than without.
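My guess at what is happening (an assumption, not confirmed against the modeling code): GPT-2 assigns positions `arange(past_length, past_length + input_length)`, so refeeding the full `generated` sequence while also passing `past` makes the maximum position index grow quadratically until it overflows the 1024-row position-embedding table. A pure-Python sketch of that arithmetic, using an 11-token prompt purely for illustration (I have not checked the exact tokenized length of my prompt):

```python
# Sketch of how GPT-2's position indices grow per decode step.
# GPT-2 assigns positions arange(past_length, past_length + input_length)
# and its position-embedding table has n_positions = 1024 rows.
N_POSITIONS = 1024

def max_position_reached(prompt_len, steps, trim_input):
    """Largest position index touched over `steps` decode iterations."""
    past_len = 0           # tokens already cached in `past`
    seq_len = prompt_len   # length of `generated` so far
    max_pos = 0
    for step in range(steps):
        if step == 0:
            input_len = prompt_len   # first step feeds the whole prompt
        elif trim_input:
            input_len = 1            # correct `past` usage: only the new token
        else:
            input_len = seq_len      # my patch above: refeeds all of `generated`
        max_pos = max(max_pos, past_len + input_len - 1)
        past_len += input_len        # the cache grows by whatever was fed
        seq_len += 1                 # one token sampled per step
    return max_pos

# Refeeding everything overflows the 1024-row table well within 40 steps
# (prompt_len=11 is illustrative, not my exact tokenized prompt length):
assert max_position_reached(11, 40, trim_input=False) > N_POSITIONS - 1
# Feeding only the new token keeps positions linear and well in range:
assert max_position_reached(11, 40, trim_input=True) == 11 + 40 - 2
```

If this reading is right, the crash at `--length 40` but not `--length 20` is just where the quadratic growth first crosses index 1023.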
Environment
Platform Darwin-18.7.0-x86_64-i386-64bit
Python 3.7.4 (default, Aug 13 2019, 15:17:50)
[Clang 4.0.1 (tags/RELEASE_401/final)]
PyTorch 1.3.0
Tensorflow 2.0.0
Additional context
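In case it helps, here is a stub sketch of the loop shape I believe the caching is meant to have: after the first step, only the most recently sampled token is passed alongside `past`. The stub model and every name below are illustrative stand-ins, not the transformers API; the real call would be something like `outputs = model(input_ids, past=past)`.

```python
# Stub sketch of incremental decoding with a cache. A fake model stands in
# so the control flow is self-contained and runnable.

def stub_model(input_ids, past=None):
    """Pretend model: returns (next_token, new_past).

    Here `past` is just a count of cached tokens; a real model would return
    cached key/value states and logits.
    """
    past_len = 0 if past is None else past
    next_token = past_len + len(input_ids)   # dummy "sampled" token
    return next_token, past_len + len(input_ids)

def generate(prompt, length):
    generated = list(prompt)
    past = None
    for _ in range(length):
        if past is None:
            input_ids = generated        # first step: feed the full prompt
        else:
            input_ids = generated[-1:]   # later steps: only the newest token
        next_token, past = stub_model(input_ids, past=past)
        generated.append(next_token)
    return generated, past

out, cache_len = generate([1, 2, 3], length=5)
assert len(out) == 3 + 5    # prompt plus 5 sampled tokens
assert cache_len == 3 + 4   # cache covers everything fed so far
```

With this shape, each step feeds one token, so position indices grow by one per step instead of compounding.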
About this issue
- Original URL
- State: closed
- Created 5 years ago
- Comments: 15 (10 by maintainers)
That's a great new part of the documentation. Thank you!