transformers: cast_bool_to_primitive breaks TensorFlow graph support.
🐛 Bug
To reproduce
```python
import tensorflow as tf
import transformers

bert = tf.function(transformers.TFBertForMaskedLM.from_pretrained('bert-base-uncased'))
for i in range(2):
    (_, hidden_state) = bert(tf.constant([[10, 11, 12]]), output_hidden_states=True)
    print(f'computed {i}')
```
Errors with:
```
ValueError: not enough values to unpack (expected 2, got 1)
```
Expected behavior
```
computed 0
computed 1
```
Same result as if `tf.function` was not used.
Environment info
Example environment: https://colab.research.google.com/gist/AndreasMadsen/593df94a3319dee58bba33a26efedeb3/untitled6.ipynb

- `transformers` version: 3.0.2
- Platform: Linux-4.19.104+-x86_64-with-Ubuntu-18.04-bionic
- Python version: 3.6.9
- PyTorch version (GPU?): 1.5.1+cu101 (False)
- Tensorflow version (GPU?): 2.2.0 (False)
- Using GPU in script?: <fill in>
- Using distributed or parallel set-up in script?: <fill in>
Details
The bug happens due to `cast_bool_to_primitive`, which was introduced in https://github.com/huggingface/transformers/commit/6e603cb7892b49a2cbbc10ba859759f92c3fb7a6. Before that commit, it was possible to get the `hidden_states` from BERT in TensorFlow graph/function mode.

Generally speaking, casting TensorFlow tensors to primitives is bad practice, as it only works in eager mode. It is also completely unnecessary in this case, as using `if bool_tensor_scalar:` works perfectly fine:
```python
def print_bool(x):
    if x:
        print('True')
    else:
        print('False')

print_bool_graph = tf.function(print_bool)

print('eager:')
print_bool(True)                # Prints True
print_bool(False)               # Prints False
print_bool(tf.constant(True))   # Prints True
print_bool(tf.constant(False))  # Prints False

print('')
print('graph:')
print_bool_graph(True)                # Prints True
print_bool_graph(False)               # Prints False
print_bool_graph(tf.constant(True))   # Prints True
print_bool_graph(tf.constant(False))  # Prints False
```
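For contrast, here is a minimal sketch (my own illustration, not the library's actual `cast_bool_to_primitive`) of how forcing a symbolic tensor into a Python primitive fails under `tf.function`:

```python
import tensorflow as tf

@tf.function
def cast_to_primitive(x):
    # bool() needs a concrete value, but during tracing `x` is a
    # symbolic tensor with no value yet, so this raises an
    # OperatorNotAllowedInGraphError instead of returning.
    return bool(x)

cast_to_primitive(tf.constant(True))  # raises during tracing
```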
I can see there are some cases where defaults are used. The right way to handle that is to implement the default handling upstream, in the first `call()` method (sketched after the snippet below). A lesser alternative would be to implement it as:
```python
def cast_bool_to_primitive(x, default_value=False):
    if x is None:
        return default_value
    return x
```
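The upstream fix might look roughly like this sketch; `SketchModel` and the `config_output_hidden_states` attribute are hypothetical names standing in for the model's config-backed default:

```python
import tensorflow as tf

class SketchModel(tf.keras.Model):
    def __init__(self, config_output_hidden_states=False):
        super().__init__()
        # Hypothetical attribute standing in for the config default.
        self.config_output_hidden_states = config_output_hidden_states

    def call(self, inputs, output_hidden_states=None):
        # Resolve the default once, up front, as a plain Python value,
        # so nothing downstream ever needs cast_bool_to_primitive.
        if output_hidden_states is None:
            output_hidden_states = self.config_output_hidden_states
        if output_hidden_states:
            # Stand-in for also returning the hidden states.
            return inputs, tf.identity(inputs)
        return (inputs,)
```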
Ok, thanks for the hints, I will review that part! I should be able to start working on it from tomorrow or Thursday.
Hi @jplu, I'm sorry, but I doubt #5468 will fix the issue. Fundamentally speaking, casting to primitives is not a good practice in TensorFlow, as it invalidates the use of `@tf.function` and is generally unnecessary, as described above. Casting to primitives is, in my experience, just never the correct solution in TensorFlow.

I do think #5468 mitigates the issue, which is maybe where the confusion is coming from. This is because the models will now correctly default to the `config` object when `output_hidden_states=True` is not specified as an input. In those cases the object property is never cast to a tensor to begin with, therefore the `@tf.function` graph will be statically compiled to always output the `hidden_states`, as intended.

However, the behavior is different when `output_hidden_states=True` is specified as an input, as it will be cast to a tensor when it becomes part of the `inputs` argument in `call()`. After that, it is not possible to convert it back to a primitive, as that invalidates `@tf.function`.

If you insist on keeping it as a primitive, the best solution might be to specify it as an aux-input, similar to `training` and `mask` in a `keras.layers.Layer`, as those don't get converted the same way (see the sketch below). I'm not familiar enough with the Keras internals to know the details here, and I think it might also be incompatible with `compute_output_shape` etc.

BTW, in the Keras RNN layers, `hidden_state` is only specified in the constructor, probably because it can get a bit messy having to specify it in the `inputs`, but I don't see anything fundamentally wrong with specifying it in `inputs`.
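To illustrate the aux-input point, a rough sketch, under my assumption that Keras only tensor-converts the first `inputs` argument while extra keyword arguments pass through unchanged:

```python
import tensorflow as tf

class Probe(tf.keras.layers.Layer):
    def call(self, inputs, training=None, output_hidden_states=False):
        # Extra keyword arguments arrive as-is, so a Python bool stays
        # a bool even when the layer is traced with tf.function.
        print('output_hidden_states is a', type(output_hidden_states))
        return inputs

probe = tf.function(Probe())
probe(tf.constant([1.0]), output_hidden_states=True)  # prints <class 'bool'>
```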
Should be fixed in the PR https://github.com/huggingface/transformers/pull/5468