TensorFlow: BatchNormalization doesn't work in graph mode (inside `@tf.function`) in TF 2

System information

  • Have I written custom code (as opposed to using a stock example script provided in TensorFlow): yes — see the code to reproduce below
  • OS: macOS
  • TensorFlow version: 2.0.0-rc0
  • Python version: 3.7

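The `tf.keras.layers.BatchNormalization` layer does not work in graph mode: when it is used inside a subclassed model whose `call` is decorated with `@tf.function`, `model.fit` fails with `InaccessibleTensorError`.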

Code to reproduce:

import tensorflow as tf
from tensorflow.keras.layers import BatchNormalization
import numpy as np

keras = tf.keras

class check_bn_model(keras.Model):
    """Minimal subclassed model that reproduces the reported failure.

    Wraps a single BatchNormalization layer and decorates ``call`` with
    ``@tf.function``. With TensorFlow 2.0.0-rc0, ``fit`` on this model
    raises ``InaccessibleTensorError`` while Keras builds the model.
    The code is intentionally left as-is so that it keeps reproducing
    the issue.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.bn = BatchNormalization()
    
    # The tf.function decorator is what triggers the error: per the report,
    # removing it, or passing dynamic=True to the constructor, makes fit()
    # succeed.
    @tf.function
    def call(self, x):
        # NOTE(review): ``call`` takes no ``training`` argument, so self.bn
        # is invoked without an explicit training flag.
        x = self.bn(x)
        return x
    
# A 10x5 float32 array of ones, used as both the input and the target.
X = np.ones((10,5)).astype('float32')
model = check_bn_model()
model.compile('adam', 'mse')
# The error is raised here. On its first call, fit() builds the model from
# the data (see _build_model_with_inputs / _set_inputs in the traceback),
# which invokes the tf.function-decorated call.
model.fit(X, X, batch_size=2, epochs=2)

Error stack trace:

InaccessibleTensorError                   Traceback (most recent call last)
<ipython-input-142-0234ba2796e1> in <module>
     12 model = check_bn_model()
     13 model.compile('adam', 'mse')
---> 14 model.fit(X, X, batch_size=2, epochs=2)

~/anaconda3/envs/tf_2.x/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_freq, max_queue_size, workers, use_multiprocessing, **kwargs)
    732         max_queue_size=max_queue_size,
    733         workers=workers,
--> 734         use_multiprocessing=use_multiprocessing)
    735 
    736   def evaluate(self,

~/anaconda3/envs/tf_2.x/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_v2.py in fit(self, model, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_freq, **kwargs)
    222           validation_data=validation_data,
    223           validation_steps=validation_steps,
--> 224           distribution_strategy=strategy)
    225 
    226       total_samples = _get_total_number_of_samples(training_data_adapter)

~/anaconda3/envs/tf_2.x/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_v2.py in _process_training_inputs(model, x, y, batch_size, epochs, sample_weights, class_weights, steps_per_epoch, validation_split, validation_data, validation_steps, shuffle, distribution_strategy, max_queue_size, workers, use_multiprocessing)
    545         max_queue_size=max_queue_size,
    546         workers=workers,
--> 547         use_multiprocessing=use_multiprocessing)
    548     val_adapter = None
    549     if validation_data:

~/anaconda3/envs/tf_2.x/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_v2.py in _process_inputs(model, x, y, batch_size, epochs, sample_weights, class_weights, shuffle, steps, distribution_strategy, max_queue_size, workers, use_multiprocessing)
    591         batch_size=batch_size,
    592         check_steps=False,
--> 593         steps=steps)
    594   adapter = adapter_cls(
    595       x,

~/anaconda3/envs/tf_2.x/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training.py in _standardize_user_data(self, x, y, sample_weight, class_weight, batch_size, check_steps, steps_name, steps, validation_split, shuffle, extract_tensors_from_dataset)
   2382     # First, we build the model on the fly if necessary.
   2383     if not self.inputs:
-> 2384       all_inputs, y_input, dict_inputs = self._build_model_with_inputs(x, y)
   2385       is_build_called = True
   2386     else:

~/anaconda3/envs/tf_2.x/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training.py in _build_model_with_inputs(self, inputs, targets)
   2585     else:
   2586       cast_inputs = inputs
-> 2587     self._set_inputs(cast_inputs)
   2588     return processed_inputs, targets, is_dict_inputs
   2589 

~/anaconda3/envs/tf_2.x/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training.py in _set_inputs(self, inputs, outputs, training)
   2672           kwargs['training'] = training
   2673       try:
-> 2674         outputs = self(inputs, **kwargs)
   2675       except NotImplementedError:
   2676         # This Model or a submodel is dynamic and hasn't overridden

~/anaconda3/envs/tf_2.x/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/base_layer.py in __call__(self, inputs, *args, **kwargs)
    800                     not base_layer_utils.is_in_eager_or_tf_function()):
    801                   with auto_control_deps.AutomaticControlDependencies() as acd:
--> 802                     outputs = call_fn(cast_inputs, *args, **kwargs)
    803                     # Wrap Tensors in `outputs` in `tf.identity` to avoid
    804                     # circular dependencies.

~/anaconda3/envs/tf_2.x/lib/python3.7/site-packages/tensorflow_core/python/eager/def_function.py in __call__(self, *args, **kwds)
    437         # Lifting succeeded, so variables are initialized and we can run the
    438         # stateless function.
--> 439         return self._stateless_fn(*args, **kwds)
    440     else:
    441       canon_args, canon_kwds = \

~/anaconda3/envs/tf_2.x/lib/python3.7/site-packages/tensorflow_core/python/eager/function.py in __call__(self, *args, **kwargs)
   1820     """Calls a graph function specialized to the inputs."""
   1821     graph_function, args, kwargs = self._maybe_define_function(args, kwargs)
-> 1822     return graph_function._filtered_call(args, kwargs)  # pylint: disable=protected-access
   1823 
   1824   @property

~/anaconda3/envs/tf_2.x/lib/python3.7/site-packages/tensorflow_core/python/eager/function.py in _filtered_call(self, args, kwargs)
   1139          if isinstance(t, (ops.Tensor,
   1140                            resource_variable_ops.BaseResourceVariable))),
-> 1141         self.captured_inputs)
   1142 
   1143   def _call_flat(self, args, captured_inputs, cancellation_manager=None):

~/anaconda3/envs/tf_2.x/lib/python3.7/site-packages/tensorflow_core/python/eager/function.py in _call_flat(self, args, captured_inputs, cancellation_manager)
   1228           {"PartitionedCall": gradient_name,
   1229            "StatefulPartitionedCall": gradient_name}):
-> 1230         flat_outputs = forward_function.call(ctx, args)
   1231     if isinstance(flat_outputs, ops.Operation) or flat_outputs is None:
   1232       # We only record function calls which have outputs.

~/anaconda3/envs/tf_2.x/lib/python3.7/site-packages/tensorflow_core/python/eager/function.py in call(self, ctx, args, cancellation_manager)
    538                 executing_eagerly=executing_eagerly,
    539                 config=config,
--> 540                 executor_type=executor_type)
    541 
    542     if executing_eagerly:

~/anaconda3/envs/tf_2.x/lib/python3.7/site-packages/tensorflow_core/python/ops/functional_ops.py in partitioned_call(args, f, tout, executing_eagerly, config, executor_type)
    857           f=f,
    858           config_proto=config,
--> 859           executor_type=executor_type)
    860     else:
    861       outputs = gen_functional_ops.partitioned_call(

~/anaconda3/envs/tf_2.x/lib/python3.7/site-packages/tensorflow_core/python/ops/gen_functional_ops.py in stateful_partitioned_call(args, Tout, f, config, config_proto, executor_type, name)
    670         "StatefulPartitionedCall", args=args, Tout=Tout, f=f, config=config,
    671                                    config_proto=config_proto,
--> 672                                    executor_type=executor_type, name=name)
    673   _result = _op.outputs[:]
    674   if not _result:

~/anaconda3/envs/tf_2.x/lib/python3.7/site-packages/tensorflow_core/python/framework/op_def_library.py in _apply_op_helper(self, op_type_name, name, **keywords)
    791         op = g.create_op(op_type_name, inputs, dtypes=None, name=scope,
    792                          input_types=input_types, attrs=attr_protos,
--> 793                          op_def=op_def)
    794       return output_structure, op_def.is_stateful, op
    795 

~/anaconda3/envs/tf_2.x/lib/python3.7/site-packages/tensorflow_core/python/framework/func_graph.py in create_op(***failed resolving arguments***)
    542       if ctxt is not None and hasattr(ctxt, "AddValue"):
    543         inp = ctxt.AddValue(inp)
--> 544       inp = self.capture(inp)
    545       inputs[i] = inp
    546     return super(FuncGraph, self)._create_op_internal(  # pylint: disable=protected-access

~/anaconda3/envs/tf_2.x/lib/python3.7/site-packages/tensorflow_core/python/framework/func_graph.py in capture(self, tensor, name)
    601               " explicit Python locals or TensorFlow collections to access"
    602               " it. Defined in: %s; accessed from: %s.\n"
--> 603               % (tensor, tensor.graph, self))
    604         inner_graph = inner_graph.outer_graph
    605       return self._capture_helper(tensor, name)

InaccessibleTensorError: The tensor 'Tensor("batch_normalization_194/batch_normalization_194_trainable:0", dtype=bool)' cannot be accessed here: it is defined in another function or code block. Use return values, explicit Python locals or TensorFlow collections to access it. Defined in: FuncGraph(name=call, id=5982930640); accessed from: FuncGraph(name=keras_graph, id=5269648784).

If you remove the `@tf.function` decorator, or pass `dynamic=True` when instantiating the model, it works; otherwise it fails. Note that this is specific to BatchNormalization: if you replace it with any other layer, it works (even other normalization layers such as LayerNormalization / InstanceNormalization).

About this issue

  • Original URL
  • State: closed
  • Created 5 years ago
  • Comments: 17 (6 by maintainers)

Most upvoted comments

Any updates/news on this issue?

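As a workaround, you can set the class attribute `BatchNormalization._USE_V2_BEHAVIOR` to `False` before using the layer. Note that this is a private (underscore-prefixed) attribute rather than a supported API, and because it is set on the class it affects every BatchNormalization layer in the process.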

import tensorflow as tf
from tensorflow.keras.layers import BatchNormalization
import numpy as np

# Workaround: switch BatchNormalization off its V2 behavior.
# NOTE: ``_USE_V2_BEHAVIOR`` is a private class attribute, not public API.
# Because it is set on the class, it applies to every BatchNormalization
# layer in this process, not only the one used below, and it may change
# or disappear in later TensorFlow releases.
BatchNormalization._USE_V2_BEHAVIOR = False

keras = tf.keras


class check_bn_model(keras.Model):
    """Subclassed model with a single BatchNormalization layer.

    ``call`` is compiled with ``@tf.function``. This variant is meant to be
    used after ``BatchNormalization._USE_V2_BEHAVIOR`` has been set to False.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.bn = BatchNormalization()

    @tf.function
    def call(self, x, training=None, mask=None):
        """Normalize ``x`` with the wrapped BatchNormalization layer.

        Args:
            x: Input tensor.
            training: Training/inference flag, forwarded to the
                BatchNormalization layer. ``None`` (the default) is passed
                through unchanged.
            mask: Accepted to match the Keras ``call`` signature; unused.

        Returns:
            The output of the BatchNormalization layer.
        """
        # Forward ``training`` explicitly. Previously it was accepted but
        # dropped, so a value such as ``model(x, training=True)`` never
        # reached the BatchNormalization layer.
        return self.bn(x, training=training)


# Same data and training setup as the failing reproduction; with the
# _USE_V2_BEHAVIOR workaround applied, this fit() call is intended to succeed.
X = np.ones((10, 5)).astype('float32')
model = check_bn_model()
model.compile('adam', 'mse')
model.fit(X, X, batch_size=2, epochs=2)

It seems the issue only happens when custom layers or subclassed models use batch norm; if batch norm is used as a standalone layer, it works. @jvishnuvardhan @robieta @mshlis @ravikyram

Getting the same problem.

Will this issue be fixed in TF 2.1?

Just checked: this still reproduces in 2.1.