tensorflow: BatchNormalization doesn't work in graph mode in tf2
System information
- Have I written custom code (as opposed to using a stock example script provided in TensorFlow): Yes
- OS: macOS
- TensorFlow version (use command below): 2.0.0-rc0
- Python version: 3.7
The tf.keras.layers.BatchNormalization layer does not work in graph mode.
Code to reproduce:
import tensorflow as tf
from tensorflow.keras.layers import BatchNormalization
import numpy as np

keras = tf.keras

class check_bn_model(keras.Model):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.bn = BatchNormalization()

    @tf.function
    def call(self, x):
        x = self.bn(x)
        return x

X = np.ones((10, 5)).astype('float32')
model = check_bn_model()
model.compile('adam', 'mse')
model.fit(X, X, batch_size=2, epochs=2)
Error stack:
InaccessibleTensorError Traceback (most recent call last)
<ipython-input-142-0234ba2796e1> in <module>
12 model = check_bn_model()
13 model.compile('adam', 'mse')
---> 14 model.fit(X, X, batch_size=2, epochs=2)
~/anaconda3/envs/tf_2.x/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_freq, max_queue_size, workers, use_multiprocessing, **kwargs)
732 max_queue_size=max_queue_size,
733 workers=workers,
--> 734 use_multiprocessing=use_multiprocessing)
735
736 def evaluate(self,
~/anaconda3/envs/tf_2.x/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_v2.py in fit(self, model, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_freq, **kwargs)
222 validation_data=validation_data,
223 validation_steps=validation_steps,
--> 224 distribution_strategy=strategy)
225
226 total_samples = _get_total_number_of_samples(training_data_adapter)
~/anaconda3/envs/tf_2.x/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_v2.py in _process_training_inputs(model, x, y, batch_size, epochs, sample_weights, class_weights, steps_per_epoch, validation_split, validation_data, validation_steps, shuffle, distribution_strategy, max_queue_size, workers, use_multiprocessing)
545 max_queue_size=max_queue_size,
546 workers=workers,
--> 547 use_multiprocessing=use_multiprocessing)
548 val_adapter = None
549 if validation_data:
~/anaconda3/envs/tf_2.x/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_v2.py in _process_inputs(model, x, y, batch_size, epochs, sample_weights, class_weights, shuffle, steps, distribution_strategy, max_queue_size, workers, use_multiprocessing)
591 batch_size=batch_size,
592 check_steps=False,
--> 593 steps=steps)
594 adapter = adapter_cls(
595 x,
~/anaconda3/envs/tf_2.x/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training.py in _standardize_user_data(self, x, y, sample_weight, class_weight, batch_size, check_steps, steps_name, steps, validation_split, shuffle, extract_tensors_from_dataset)
2382 # First, we build the model on the fly if necessary.
2383 if not self.inputs:
-> 2384 all_inputs, y_input, dict_inputs = self._build_model_with_inputs(x, y)
2385 is_build_called = True
2386 else:
~/anaconda3/envs/tf_2.x/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training.py in _build_model_with_inputs(self, inputs, targets)
2585 else:
2586 cast_inputs = inputs
-> 2587 self._set_inputs(cast_inputs)
2588 return processed_inputs, targets, is_dict_inputs
2589
~/anaconda3/envs/tf_2.x/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training.py in _set_inputs(self, inputs, outputs, training)
2672 kwargs['training'] = training
2673 try:
-> 2674 outputs = self(inputs, **kwargs)
2675 except NotImplementedError:
2676 # This Model or a submodel is dynamic and hasn't overridden
~/anaconda3/envs/tf_2.x/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/base_layer.py in __call__(self, inputs, *args, **kwargs)
800 not base_layer_utils.is_in_eager_or_tf_function()):
801 with auto_control_deps.AutomaticControlDependencies() as acd:
--> 802 outputs = call_fn(cast_inputs, *args, **kwargs)
803 # Wrap Tensors in `outputs` in `tf.identity` to avoid
804 # circular dependencies.
~/anaconda3/envs/tf_2.x/lib/python3.7/site-packages/tensorflow_core/python/eager/def_function.py in __call__(self, *args, **kwds)
437 # Lifting succeeded, so variables are initialized and we can run the
438 # stateless function.
--> 439 return self._stateless_fn(*args, **kwds)
440 else:
441 canon_args, canon_kwds = \
~/anaconda3/envs/tf_2.x/lib/python3.7/site-packages/tensorflow_core/python/eager/function.py in __call__(self, *args, **kwargs)
1820 """Calls a graph function specialized to the inputs."""
1821 graph_function, args, kwargs = self._maybe_define_function(args, kwargs)
-> 1822 return graph_function._filtered_call(args, kwargs) # pylint: disable=protected-access
1823
1824 @property
~/anaconda3/envs/tf_2.x/lib/python3.7/site-packages/tensorflow_core/python/eager/function.py in _filtered_call(self, args, kwargs)
1139 if isinstance(t, (ops.Tensor,
1140 resource_variable_ops.BaseResourceVariable))),
-> 1141 self.captured_inputs)
1142
1143 def _call_flat(self, args, captured_inputs, cancellation_manager=None):
~/anaconda3/envs/tf_2.x/lib/python3.7/site-packages/tensorflow_core/python/eager/function.py in _call_flat(self, args, captured_inputs, cancellation_manager)
1228 {"PartitionedCall": gradient_name,
1229 "StatefulPartitionedCall": gradient_name}):
-> 1230 flat_outputs = forward_function.call(ctx, args)
1231 if isinstance(flat_outputs, ops.Operation) or flat_outputs is None:
1232 # We only record function calls which have outputs.
~/anaconda3/envs/tf_2.x/lib/python3.7/site-packages/tensorflow_core/python/eager/function.py in call(self, ctx, args, cancellation_manager)
538 executing_eagerly=executing_eagerly,
539 config=config,
--> 540 executor_type=executor_type)
541
542 if executing_eagerly:
~/anaconda3/envs/tf_2.x/lib/python3.7/site-packages/tensorflow_core/python/ops/functional_ops.py in partitioned_call(args, f, tout, executing_eagerly, config, executor_type)
857 f=f,
858 config_proto=config,
--> 859 executor_type=executor_type)
860 else:
861 outputs = gen_functional_ops.partitioned_call(
~/anaconda3/envs/tf_2.x/lib/python3.7/site-packages/tensorflow_core/python/ops/gen_functional_ops.py in stateful_partitioned_call(args, Tout, f, config, config_proto, executor_type, name)
670 "StatefulPartitionedCall", args=args, Tout=Tout, f=f, config=config,
671 config_proto=config_proto,
--> 672 executor_type=executor_type, name=name)
673 _result = _op.outputs[:]
674 if not _result:
~/anaconda3/envs/tf_2.x/lib/python3.7/site-packages/tensorflow_core/python/framework/op_def_library.py in _apply_op_helper(self, op_type_name, name, **keywords)
791 op = g.create_op(op_type_name, inputs, dtypes=None, name=scope,
792 input_types=input_types, attrs=attr_protos,
--> 793 op_def=op_def)
794 return output_structure, op_def.is_stateful, op
795
~/anaconda3/envs/tf_2.x/lib/python3.7/site-packages/tensorflow_core/python/framework/func_graph.py in create_op(***failed resolving arguments***)
542 if ctxt is not None and hasattr(ctxt, "AddValue"):
543 inp = ctxt.AddValue(inp)
--> 544 inp = self.capture(inp)
545 inputs[i] = inp
546 return super(FuncGraph, self)._create_op_internal( # pylint: disable=protected-access
~/anaconda3/envs/tf_2.x/lib/python3.7/site-packages/tensorflow_core/python/framework/func_graph.py in capture(self, tensor, name)
601 " explicit Python locals or TensorFlow collections to access"
602 " it. Defined in: %s; accessed from: %s.\n"
--> 603 % (tensor, tensor.graph, self))
604 inner_graph = inner_graph.outer_graph
605 return self._capture_helper(tensor, name)
InaccessibleTensorError: The tensor 'Tensor("batch_normalization_194/batch_normalization_194_trainable:0", dtype=bool)' cannot be accessed here: it is defined in another function or code block. Use return values, explicit Python locals or TensorFlow collections to access it. Defined in: FuncGraph(name=call, id=5982930640); accessed from: FuncGraph(name=keras_graph, id=5269648784).
If you remove the @tf.function decorator, or if you pass dynamic=True at model instantiation, it works; otherwise it fails. Note that this is specific to BatchNormalization: if you replace it with any other layer it works, even other normalization layers such as layer norm or instance norm.
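For concreteness, a minimal sketch of both workarounds applied to the repro above (it reuses the repro's names; the class name in workaround A is illustrative, and the dynamic=True behavior is as described in this report):

```python
# Workaround A: drop the @tf.function decorator from call().
class check_bn_model_eager(keras.Model):  # illustrative name
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.bn = BatchNormalization()

    def call(self, x):  # no @tf.function here
        return self.bn(x)

# Workaround B: keep the decorated model from the repro, but instantiate
# it with dynamic=True so Keras runs call() eagerly instead of tracing it.
model = check_bn_model(dynamic=True)
model.compile('adam', 'mse')
model.fit(X, X, batch_size=2, epochs=2)
```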
Any updates/news on this issue?
As a workaround, you can set the class variable _USE_V2_BEHAVIOR to False.
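A minimal sketch of that workaround applied to the repro above; note that `_USE_V2_BEHAVIOR` is a private attribute, so it may change between releases:

```python
from tensorflow.keras.layers import BatchNormalization

# Flip the flag on the class before the model builds its layers, so
# BatchNormalization falls back to its V1 code path.
BatchNormalization._USE_V2_BEHAVIOR = False

model = check_bn_model()  # the repro model, @tf.function left in place
model.compile('adam', 'mse')
model.fit(X, X, batch_size=2, epochs=2)
```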
It seems the issue only happens when custom layers or subclassed models use batch norm; if batch norm is used as a standalone layer, it works. @jvishnuvardhan @robieta @mshlis @ravikyram
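For example, the same layer trains without error when used standalone in a Sequential model (a sketch, not from the thread):

```python
import numpy as np
import tensorflow as tf

X = np.ones((10, 5)).astype('float32')
model = tf.keras.Sequential([
    tf.keras.layers.BatchNormalization(input_shape=(5,)),
])
model.compile('adam', 'mse')
model.fit(X, X, batch_size=2, epochs=2)  # no InaccessibleTensorError
```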
Getting the same problem.
Just checked, and this still appears in 2.1.