transformers: cast_bool_to_primitive breaks TensorFlow graph support.

๐Ÿ› Bug

To reproduce

import tensorflow as tf
import transformers

bert = tf.function(transformers.TFBertForMaskedLM.from_pretrained('bert-base-uncased'))

for i in range(2):
    (_, hidden_state) = bert(tf.constant([[10,11,12]]), output_hidden_states=True)
    print(f'computed {i}')

Errors with

ValueError: not enough values to unpack (expected 2, got 1)

Expected behavior

computed 0
computed 1

Same result as if tf.function were not used.

Environment info

Example environment: https://colab.research.google.com/gist/AndreasMadsen/593df94a3319dee58bba33a26efedeb3/untitled6.ipynb

  • transformers version: 3.0.2
  • Platform: Linux-4.19.104+-x86_64-with-Ubuntu-18.04-bionic
  • Python version: 3.6.9
  • PyTorch version (GPU?): 1.5.1+cu101 (False)
  • Tensorflow version (GPU?): 2.2.0 (False)
  • Using GPU in script?: <fill in>
  • Using distributed or parallel set-up in script?: <fill in>

Details

The bug happens due to cast_bool_to_primitive, which was introduced in https://github.com/huggingface/transformers/commit/6e603cb7892b49a2cbbc10ba859759f92c3fb7a6. Before that, it was possible to get the hidden_states from BERT in TensorFlow graph/function mode.

Generally speaking, casting TensorFlow tensors to primitives is not a good practice, as it only works in eager mode. It is also completely unnecessary in this case, as using if bool_tensor_scalar: works perfectly fine.

import tensorflow as tf

def print_bool(x):
    if x:
        print('True')
    else:
        print('False')

print_bool_graph = tf.function(print_bool)

print('eager:')
print_bool(True) # Prints True
print_bool(False) # Prints False
print_bool(tf.constant(True)) # Prints True
print_bool(tf.constant(False)) # Prints False

print('')
print('graph:')
print_bool_graph(True) # Prints True
print_bool_graph(False) # Prints False
print_bool_graph(tf.constant(True)) # Prints True
print_bool_graph(tf.constant(False)) # Prints False

I can see there are some cases where default values are needed. The right way to handle that is to implement the default handling upstream, in the first call() method. A lesser fix would be to implement it as:

def cast_bool_to_primitive(x, default_value=False):
    if x is None:
        return default_value
    return x
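As a hedged illustration of the upstream approach (all names here are hypothetical, not the actual transformers API), resolving the default at the top of call() keeps the flag a plain Python bool, so it never needs to be cast back from a tensor:

```python
# Hypothetical sketch of the "upstream" fix; names are illustrative,
# not the actual transformers implementation.
class MyModel:
    def __init__(self, output_hidden_states_default=False):
        # Mirrors a config-level default, as in a transformers config object.
        self.output_hidden_states_default = output_hidden_states_default

    def call(self, inputs, output_hidden_states=None):
        # Resolve the default once, in pure Python, before the flag can
        # ever be wrapped into a tensor.
        if output_hidden_states is None:
            output_hidden_states = self.output_hidden_states_default
        # From here on the flag is a plain bool, so a normal `if` is
        # compiled statically by tf.function.
        outputs = (inputs,)
        if output_hidden_states:
            outputs = outputs + (inputs,)  # stand-in for hidden states
        return outputs
```

Structured this way, cast_bool_to_primitive is never needed: the branch is decided before tracing, and the graph is compiled for that branch.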

About this issue

  • Original URL
  • State: closed
  • Created 4 years ago
  • Comments: 19 (15 by maintainers)

Most upvoted comments

Ok, thanks for the hints, I will review that part! I should be able to start working on it from tomorrow or Thursday.

Hi @jplu, I'm sorry, but I doubt #5468 will fix the issue. Fundamentally speaking, casting to primitives is not a good practice in TensorFlow, as it invalidates the use of @tf.function and is generally unnecessary as described above. Casting to primitives is, in my experience, just never the correct solution in TensorFlow.

I do think #5468 mitigates the issue, which is maybe where the confusion is coming from. This is because the models will now correctly default to the config object when output_hidden_states=True is not specified as an input. In those cases the config property is never cast to a tensor to begin with, therefore the @tf.function graph will be statically compiled to always output the hidden_states, as intended.

However, the behavior is different when output_hidden_states=True is specified as an input, as it will be cast to a Tensor when it becomes part of the inputs argument in call(). After that, it is not possible to convert back to a primitive, as that invalidates @tf.function.
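To make the claim concrete, here is a minimal sketch (assuming TF 2.x) of the difference: converting a symbolic tensor back to a Python bool inside a traced function raises, while an ordinary `if` on the same tensor is handled by AutoGraph:

```python
import tensorflow as tf

@tf.function
def cast_to_primitive(x):
    # bool() on a symbolic tensor is the "cast to primitive" pattern:
    # it works in eager mode but raises during tracing.
    return bool(x)

@tf.function
def branch_on_tensor(x):
    # AutoGraph rewrites this `if` into tf.cond, so it traces fine.
    if x:
        return tf.constant(1)
    return tf.constant(0)

print(int(branch_on_tensor(tf.constant(True))))   # 1
try:
    cast_to_primitive(tf.constant(True))
except Exception as e:
    # On TF 2.x this is typically an OperatorNotAllowedInGraphError.
    print(type(e).__name__)
```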

If you insist on keeping it as a primitive, the best solution might be to specify it as an aux-input, similar to training and mask in a keras.layers.Layer, as they don't get converted the same way. I'm not familiar enough with the Keras internals to know the details here, and I think it might also be incompatible with compute_output_shape etc.
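For comparison, a sketch of the aux-input idea (assuming TF 2.x; argument names are illustrative): when the flag is passed as a plain Python value rather than inside the inputs, tf.function treats it as static and traces a separate graph per value, much like Keras handles the training argument:

```python
import tensorflow as tf

trace_count = 0

@tf.function
def model(x, output_hidden_states=False):
    # This Python side effect runs only while tracing, so it counts
    # how many graphs tf.function builds.
    global trace_count
    trace_count += 1
    # With a plain Python bool, this `if` is resolved at trace time:
    # each bool value gets its own statically compiled graph.
    if output_hidden_states:
        return x, x  # stand-in for (output, hidden_states)
    return (x,)

model(tf.constant([1]), output_hidden_states=False)
model(tf.constant([1]), output_hidden_states=True)
model(tf.constant([2]), output_hidden_states=True)  # same signature, trace reused
print(trace_count)  # 2: one graph per distinct bool value
```

The retrace per bool value is exactly the desired behavior here: each graph statically includes or excludes the hidden states, with no tf.cond and no casting.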

BTW, in the Keras RNN layers, hidden_state is only specified in the constructor, probably because it can get a bit messy having to specify it in the inputs, but I don't see anything fundamentally wrong with specifying it in inputs.