transformers: __init__() missing 1 required positional argument: 'logits'

🐛 Bug

Information

Model I am using (Bert, XLNet …):

Language I am using the model on (English, Chinese …):

The problem arises when using:

  • the official example scripts: (give details below)
  • my own modified scripts: (give details below)

The task I am working on is:

  • an official GLUE/SQuAD task: (give the name)
  • my own task or dataset: (give details below)

To reproduce

Steps to reproduce the behavior:

  1. python ./examples/text-classification/run_glue.py --model_name_or_path bert-base-uncased --task_name $TASK_NAME --do_train --do_eval --data_dir $GLUE_DIR/$TASK_NAME --max_seq_length 128 --per_device_eval_batch_size=2 --per_device_train_batch_size=2 --learning_rate 2e-5 --num_train_epochs 3.0 --output_dir /tmp/$TASK_NAME/

  File "./examples/text-classification/run_glue.py", line 246, in <module>
    main()
  File "./examples/text-classification/run_glue.py", line 173, in main
    model_path=model_args.model_name_or_path if os.path.isdir(model_args.model_name_or_path) else None
  File "/work/vnhh/anaconda3/lib/python3.6/site-packages/transformers/trainer.py", line 499, in train
    tr_loss += self._training_step(model, inputs, optimizer)
  File "/work/vnhh/anaconda3/lib/python3.6/site-packages/transformers/trainer.py", line 622, in _training_step
    outputs = model(**inputs)
  File "/work/vnhh/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py", line 541, in __call__
    result = self.forward(*input, **kwargs)
  File "/work/vnhh/anaconda3/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 153, in forward
    return self.gather(outputs, self.output_device)
  File "/work/vnhh/anaconda3/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 165, in gather
    return gather(outputs, output_device, dim=self.dim)
  File "/work/vnhh/anaconda3/lib/python3.6/site-packages/torch/nn/parallel/scatter_gather.py", line 68, in gather
    res = gather_map(outputs)
  File "/work/vnhh/anaconda3/lib/python3.6/site-packages/torch/nn/parallel/scatter_gather.py", line 63, in gather_map
    return type(out)(map(gather_map, zip(*outputs)))
TypeError: __init__() missing 1 required positional argument: 'logits'

Expected behavior

Training should run and finish without errors.

Environment info

  • transformers version: 3.0.2
  • Platform: Linux-4.4.0-165-generic-x86_64-with-debian-stretch-sid
  • Python version: 3.6.5
  • PyTorch version (GPU?): 1.3.1 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Using GPU in script?: <fill in>
  • Using distributed or parallel set-up in script?: <fill in>
  • tensorboardX: 1.9.0

About this issue

  • State: closed
  • Created 4 years ago
  • Reactions: 2
  • Comments: 24 (18 by maintainers)

Most upvoted comments

@csarron, this should fix it.

--- a/examples/question-answering/run_squad.py
+++ b/examples/question-answering/run_squad.py
@@ -199,6 +199,9 @@ def train(args, train_dataset, model, tokenizer):
                         {"langs": (torch.ones(batch[0].shape, dtype=torch.int64) * args.lang_id).to(args.device)}
                     )

+            if isinstance(model, torch.nn.DataParallel):
+                inputs["return_tuple"] = True
+
             outputs = model(**inputs)
             # model outputs are always tuple in transformers (see doc)
             loss = outputs[0]

It appears this will now need to be added everywhere before the model is invoked, and users who write their own training loops will need to do the same if they intend to use DataParallel.
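For anyone coding their own loop, a minimal sketch of the same workaround (the training_step function here is illustrative, not from the repo; return_tuple is the same kwarg used in the diff above):

import torch

def training_step(model, inputs):
    # DataParallel gathers the per-GPU outputs with gather_map, which cannot
    # reconstruct the output dataclasses, so request plain tuples instead.
    if isinstance(model, torch.nn.DataParallel):
        inputs["return_tuple"] = True
    outputs = model(**inputs)
    loss = outputs[0]  # outputs are plain tuples when return_tuple=True
    return loss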

Surely, there must be a better way. I suppose that when this neat dataclass feature was added it wasn't tested on nn.DataParallel. Perhaps it's best to back it out, work out how to get PyTorch to support dataclasses in scatter/gather, and then reintroduce it, with perhaps a monkeypatch for older PyTorch versions. https://github.com/pytorch/pytorch/issues/41327
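To make the failure mode concrete, here's a minimal, self-contained sketch (FakeClassifierOutput is a hypothetical stand-in for the library's output dataclass; the reconstruction call paraphrases gather_map in torch/nn/parallel/scatter_gather.py):

from dataclasses import dataclass
from typing import Optional
import torch

@dataclass
class FakeClassifierOutput:
    # Like the 3.0.x-style outputs, fields without defaults become
    # required positional arguments of the generated __init__.
    loss: Optional[torch.Tensor]
    logits: torch.Tensor

out = FakeClassifierOutput(loss=torch.tensor(0.5), logits=torch.randn(2, 2))

# gather_map rebuilds containers via type(out)(map(gather_map, zip(*outputs))),
# i.e. with a single positional argument: it lands in 'loss' and leaves
# 'logits' unfilled, reproducing the error above.
try:
    type(out)(map(lambda x: x, zip()))
except TypeError as err:
    print(err)  # __init__() missing 1 required positional argument: 'logits'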

p.s. Note that the project's scripts/modules don't consistently import torch.nn as nn, so sometimes it's torch.nn.DataParallel and other times nn.DataParallel.

Shall I be concerned about the reliability of the results?

If you're referring to the https://github.com/huggingface/transformers/pull/5685 commit, there is no reason to be concerned. There was no "functional" change per se; this is really about sorting out the API and making it consistent.