transformers: TensorFlow "predict" returns empty output with MirroredStrategy
I’m trying to use the `predict` method of the Keras/TensorFlow API, but it returns an empty output even though the input is processed. Calling the model directly works.
EDIT: the `predict` method works correctly if the model is loaded under a single-GPU strategy.
Environment info
- `transformers` version: 4.5.1
- Platform: Linux CentOS 8.1
- Python version: 3.7.10
- PyTorch version (GPU?): -
- Tensorflow version (GPU?): 2.3.2 (True)
- Using GPU in script?: yes
- Using distributed or parallel set-up in script?: multi-gpu on a single machine
Who can help
Information
Model I am using: Bert
The problem arises when using:
- the official example scripts: (give details below)
- my own modified scripts: (give details below)
The tasks I am working on is:
- an official GLUE/SQuAD task: (give the name)
- my own task or dataset: (give details below)
To reproduce
Steps to reproduce the behavior:
```python
from transformers import BertTokenizerFast, TFBertForSequenceClassification
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
# strategy = tf.distribute.OneDeviceStrategy("/gpu:0")

with strategy.scope():
    tf_model = TFBertForSequenceClassification.from_pretrained('bert-base-cased')
    tokenizer = BertTokenizerFast.from_pretrained('bert-base-cased')

inputs = tokenizer('This is a test', 'Esto es una prueba',
                   return_tensors='tf', max_length=200,
                   padding='max_length', truncation=True,
                   return_attention_mask=True,
                   return_token_type_ids=False)

print(tf_model.predict([inputs["input_ids"], inputs["attention_mask"]],
                       verbose=1))
print(tf_model([inputs["input_ids"], inputs["attention_mask"]]))
```
```
All model checkpoint layers were used when initializing TFBertForSequenceClassification.
Some layers of TFBertForSequenceClassification were not initialized from the model checkpoint at bert-base-cased and are newly initialized: ['classifier']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
WARNING:tensorflow:From /venv/lib/python3.7/site-packages/tensorflow/python/data/ops/multi_device_iterator_ops.py:601: get_next_as_optional (from tensorflow.python.data.ops.iterator_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.data.Iterator.get_next_as_optional()` instead.
WARNING:tensorflow:The parameters `output_attentions`, `output_hidden_states` and `use_cache` cannot be updated when calling a model. They have to be set to True/False in the config object (i.e.: `config=XConfig.from_pretrained('name', output_attentions=True)`).
WARNING:tensorflow:The parameter `return_dict` cannot be set in graph mode and will always be set to `True`.
WARNING:tensorflow:The parameters `output_attentions`, `output_hidden_states` and `use_cache` cannot be updated when calling a model. They have to be set to True/False in the config object (i.e.: `config=XConfig.from_pretrained('name', output_attentions=True)`).
WARNING:tensorflow:The parameter `return_dict` cannot be set in graph mode and will always be set to `True`.
WARNING:tensorflow:The parameters `output_attentions`, `output_hidden_states` and `use_cache` cannot be updated when calling a model. They have to be set to True/False in the config object (i.e.: `config=XConfig.from_pretrained('name', output_attentions=True)`).
WARNING:tensorflow:The parameter `return_dict` cannot be set in graph mode and will always be set to `True`.
WARNING:tensorflow:The parameters `output_attentions`, `output_hidden_states` and `use_cache` cannot be updated when calling a model. They have to be set to True/False in the config object (i.e.: `config=XConfig.from_pretrained('name', output_attentions=True)`).
WARNING:tensorflow:The parameter `return_dict` cannot be set in graph mode and will always be set to `True`.
1/1 [==============================] - 0s 241us/step
TFSequenceClassifierOutput(loss=None, logits=None, hidden_states=None, attentions=None)
TFSequenceClassifierOutput(loss=None, logits=<tf.Tensor: shape=(1, 2), dtype=float32, numpy=array([[-0.47814545, 0.35146457]], dtype=float32)>, hidden_states=None, attentions=None)
```
Expected behavior
The output of `predict` should be the same as when the model is called directly.
About this issue
- Original URL
- State: closed
- Created 3 years ago
- Reactions: 2
- Comments: 23 (17 by maintainers)
Hi, any updates on this issue?
Putting this here as a writeup of what we know so far:
The issue is not caused by returning an `OrderedDict`, but by the fact that we return a `TFBaseModelOutput`, which is a subclass of `OrderedDict` decorated with `dataclass`. Refer to the code here. If we just return a `dict`, an `OrderedDict`, or a `ModelOutput` (the parent class of `TFBaseModelOutput`, itself subclassed from `OrderedDict`), everything works okay. The central issue is therefore this dataclass, which will probably need to be removed. We’re looking at how we can do that now!

I’ve managed to reproduce this, but I’m very confused about the cause, especially because I’m pretty sure I’ve used `model.predict` with `MirroredStrategy` in our codebase before.
I’ve tested your code snippet with a standard RNN instead of BERT and confirmed that it works fine, and I tried distilbert instead of BERT and the problem remained, so the problem does seem to be the combination of MirroredStrategy and our models.
I’m going to keep poking around at this, but if you discover anything else that might help me figure out what’s going on, please let me know!
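The dataclass hypothesis in the writeup above can be illustrated with a small pure-Python sketch (no TensorFlow required). The `ToyOutput` class and `flatten` helper below are hypothetical stand-ins, not the actual transformers or Keras code: decorating an `OrderedDict` subclass with `@dataclass` stores its fields as instance *attributes* rather than mapping items, so any utility that walks the output as a dict (the way `tf.nest`-style flattening does) sees an empty container.

```python
from collections import OrderedDict
from dataclasses import dataclass

def flatten(structure):
    # Loosely mimics nest-style flattening: descend into mappings via their items.
    if isinstance(structure, dict):  # OrderedDict is a dict subclass
    	flat = []
    	for key in sorted(structure):
            flat.extend(flatten(structure[key]))
    	return flat
    return [structure]

@dataclass
class ToyOutput(OrderedDict):
    # The generated __init__ does `self.logits = ...`, which sets an
    # attribute on the instance, not a key/value item in the mapping.
    logits: object = None

out = ToyOutput(logits=[0.47, -0.35])
print(out.logits)                          # attribute access works: [0.47, -0.35]
print(len(out))                            # but as a mapping it is empty: 0
print(flatten(out))                        # [] -- nothing for dict-walking code to collect
print(flatten({"logits": [0.47, -0.35]}))  # a plain dict is seen: [[0.47, -0.35]]
```

This would explain the empty `TFSequenceClassifierOutput` from `predict`: the values exist as attributes, but whatever flattens the outputs by mapping iteration finds nothing.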
Working with 1024 samples and a batch size of 8 per GPU.
Hi, just keeping this issue alive! I’ve traced the issue to the way we return values from the `call()` methods - I think Keras doesn’t like what we do with a subclassed `OrderedDict`. We’re going to reach out to our contacts at Google in the next couple of days and figure out the best approach - whether we need to refactor that completely, or whether there’s an easy workaround.

@jmwoloso Sure, if you’d like! If you have any questions along the way, feel free to ask.
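One way to see why returning a plain dict (or a `ModelOutput`-style class that keeps its fields in sync with the mapping) avoids the problem: a `__post_init__` hook can mirror the dataclass fields into the underlying `OrderedDict` items, so dict-style iteration sees the values again. This is a sketch of the general pattern with hypothetical class names, not the actual transformers implementation:

```python
from collections import OrderedDict
from dataclasses import dataclass, fields

class SyncedOutput(OrderedDict):
    # Hypothetical base class: after the dataclass-generated __init__ runs,
    # it calls __post_init__, which mirrors non-None fields into the mapping.
    def __post_init__(self):
        for f in fields(self):
            value = getattr(self, f.name)
            if value is not None:
                self[f.name] = value

@dataclass
class ToyClassifierOutput(SyncedOutput):
    logits: object = None
    hidden_states: object = None

out = ToyClassifierOutput(logits=[0.47, -0.35])
print(list(out.items()))  # [('logits', [0.47, -0.35])] -- visible as a mapping
print(out.logits)         # [0.47, -0.35] -- attribute access still works
```

Unset fields (`hidden_states` here) stay out of the mapping, so dict-walking code only ever sees the values that were actually returned.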
Definitely looking forward to a fix for this. How can we help, @Rocketknight1?
Update: this bug appears in our `run_text_classification.py` script too, again only when using `predict()`. I’m investigating.