tensorflow: AutoGraph fails inside a Keras model's train_step containing a for loop over a tensor
Please make sure that this is a bug. As per our GitHub Policy, we only address code/doc bugs, performance issues, feature requests and build/installation issues on GitHub.
System information
- Have I written custom code (as opposed to using a stock example script provided in TensorFlow): yes
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Linux Ubuntu 18.04
- Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if the issue happens on mobile device: no
- TensorFlow installed from (source or binary): binary (docker image latest-gpu-py3)
- TensorFlow version (use command below): 2.3
- Python version: Python 3.6.9
- Bazel version (if compiling from source):
- GCC/Compiler version (if compiling from source):
- CUDA/cuDNN version:
- GPU model and memory: V100
Describe the current behavior
When writing a Python `for` loop inside a `tf.keras.Model.train_step` override, I get the following error:
OperatorNotAllowedInGraphError: iterating over tf.Tensor is not allowed: AutoGraph did convert this function. This might indicate you are trying to use an unsupported feature.
The same loop works correctly outside of a Keras model while still decorated with `tf.function`.
Describe the expected behavior
AutoGraph should support iterating over a tensor inside a Keras model as well.
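For context, AutoGraph can convert exactly this kind of loop when it is actually applied. A minimal sketch (the function name `loop_over_tensor` is just illustrative) that prints the converted source, where the `for` over a tensor is rewritten into `ag__.for_stmt` (which builds a `tf.while_loop`) instead of Python iteration:

```python
import tensorflow as tf

def loop_over_tensor():
    t = tf.Variable(0)
    for n in tf.range(tf.constant(10)):
        t.assign_add(n)
    return t

# Print the AutoGraph-converted source: the `for` statement over a
# tf.Tensor becomes a call to ag__.for_stmt rather than tensor __iter__.
print(tf.autograph.to_code(loop_over_tensor))
```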
Standalone code to reproduce the issue
import tensorflow as tf
import numpy as np

t = tf.Variable(0)

@tf.function()
def foo():
    for n in tf.range(tf.constant(10)):
        t.assign_add(n)
    return t

nt = foo()
nt  # <tf.Tensor: shape=(), dtype=int32, numpy=45>
class mymodel(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.t = tf.Variable(0)

    def train_step(self, data):
        for n in tf.range(tf.constant(10)):
            self.t.assign_add(n)
        return {"loss": self.t}

mm = mymodel()
mm.compile()
mm.fit(np.random.random((5)), steps_per_epoch=1)  # this doesn't work, see trace below
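A possible workaround (a sketch, not an official fix; the class name `WhileLoopModel` is hypothetical): build the loop explicitly with `tf.while_loop`, so `train_step` no longer needs AutoGraph to convert a Python `for` statement:

```python
import numpy as np
import tensorflow as tf

class WhileLoopModel(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.t = tf.Variable(0)

    def train_step(self, data):
        # Explicit graph-mode loop instead of `for n in tf.range(...)`.
        def body(n):
            # Make sure the assignment runs before the counter update.
            with tf.control_dependencies([self.t.assign_add(n)]):
                return [n + 1]

        tf.while_loop(lambda n: n < 10, body, [tf.constant(0)])
        return {"loss": self.t}

mm = WhileLoopModel()
mm.compile()
mm.fit(np.random.random((5)), steps_per_epoch=1)  # no OperatorNotAllowedInGraphError
```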
Other info / logs
OperatorNotAllowedInGraphError            Traceback (most recent call last)
<ipython-input-18-c68155fbb474> in <module>
----> 1 mm.fit(np.random.random((5)), steps_per_epoch=1)

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/training.py in _method_wrapper(self, *args, **kwargs)
    106   def _method_wrapper(self, *args, **kwargs):
    107     if not self._in_multi_worker_mode():  # pylint: disable=protected-access
--> 108       return method(self, *args, **kwargs)
    109 
    110     # Running inside run_distribute_coordinator already.

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_batch_size, validation_freq, max_queue_size, workers, use_multiprocessing)
   1096               batch_size=batch_size):
   1097             callbacks.on_train_batch_begin(step)
-> 1098             tmp_logs = train_function(iterator)
   1099             if data_handler.should_sync:
   1100               context.async_wait()

/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/def_function.py in __call__(self, *args, **kwds)
    778       else:
    779         compiler = "nonXla"
--> 780         result = self._call(*args, **kwds)
    781 
    782     new_tracing_count = self._get_tracing_count()

/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/def_function.py in _call(self, *args, **kwds)
    821       # This is the first call of __call__, so we have to initialize.
    822       initializers = []
--> 823       self._initialize(args, kwds, add_initializers_to=initializers)
    824     finally:
    825       # At this point we know that the initialization is complete (or less

/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/def_function.py in _initialize(self, args, kwds, add_initializers_to)
    695     self._concrete_stateful_fn = (
    696         self._stateful_fn._get_concrete_function_internal_garbage_collected(  # pylint: disable=protected-access
--> 697             *args, **kwds))
    698 
    699   def invalid_creator_scope(*unused_args, **unused_kwds):

/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py in _get_concrete_function_internal_garbage_collected(self, *args, **kwargs)
   2853       args, kwargs = None, None
   2854     with self._lock:
-> 2855       graph_function, _, _ = self._maybe_define_function(args, kwargs)
   2856     return graph_function
   2857 

/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py in _maybe_define_function(self, args, kwargs)
   3211 
   3212       self._function_cache.missed.add(call_context_key)
-> 3213       graph_function = self._create_graph_function(args, kwargs)
   3214       self._function_cache.primary[cache_key] = graph_function
   3215       return graph_function, args, kwargs

/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py in _create_graph_function(self, args, kwargs, override_flat_arg_shapes)
   3073             arg_names=arg_names,
   3074             override_flat_arg_shapes=override_flat_arg_shapes,
-> 3075             capture_by_value=self._capture_by_value),
   3076         self._function_attributes,
   3077         function_spec=self.function_spec,

/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/func_graph.py in func_graph_from_py_func(name, python_func, args, kwargs, signature, func_graph, autograph, autograph_options, add_control_dependencies, arg_names, op_return_value, collections, capture_by_value, override_flat_arg_shapes)
    984           _, original_func = tf_decorator.unwrap(python_func)
    985 
--> 986       func_outputs = python_func(*func_args, **func_kwargs)
    987 
    988       # invariant: func_outputs contains only Tensors, CompositeTensors,

/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/def_function.py in wrapped_fn(*args, **kwds)
    598         # __wrapped__ allows AutoGraph to swap in a converted function. We give
    599         # the function a weak reference to itself to avoid a reference cycle.
--> 600         return weak_wrapped_fn().__wrapped__(*args, **kwds)
    601       weak_wrapped_fn = weakref.ref(wrapped_fn)
    602 

/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/func_graph.py in wrapper(*args, **kwargs)
    971         except Exception as e:  # pylint:disable=broad-except
    972           if hasattr(e, "ag_error_metadata"):
--> 973             raise e.ag_error_metadata.to_exception(e)
    974           else:
    975             raise
OperatorNotAllowedInGraphError: in user code:

    /usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/training.py:806 train_function  *
        return step_function(self, iterator)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/training.py:796 step_function  **
        outputs = model.distribute_strategy.run(run_step, args=(data,))
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/distribute_lib.py:1211 run
        return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/distribute_lib.py:2585 call_for_each_replica
        return self._call_for_each_replica(fn, args, kwargs)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/distribute_lib.py:2945 _call_for_each_replica
        return fn(*args, **kwargs)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/training.py:789 run_step  **
        outputs = model.train_step(data)
    <ipython-input-12-62f0dcb0797d>:6 train_step
        for n in tf.range(tf.constant(10)):
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/ops.py:503 __iter__
        self._disallow_iteration()
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/ops.py:496 _disallow_iteration
        self._disallow_when_autograph_enabled("iterating over `tf.Tensor`")
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/ops.py:474 _disallow_when_autograph_enabled
        " indicate you are trying to use an unsupported feature.".format(task))

    OperatorNotAllowedInGraphError: iterating over `tf.Tensor` is not allowed: AutoGraph did convert this function. This might indicate you are trying to use an unsupported feature.
About this issue
- Original URL
- State: open
- Created 4 years ago
- Reactions: 2
- Comments: 18 (8 by maintainers)
@jvishnuvardhan
I finally found the difference between our tests. In your last gist (the one that runs without the decorator), you changed the original example so that the `for` loop iterates over the Python object `range(10)` rather than the original `tf.range(tf.constant(10))`; iterating over the tensor is what leads to the issue and the original error I reported above. Changing it back to the original example reproduces the issue, as the sketch below illustrates. Thanks!
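To make the contrast concrete, here is a minimal sketch (the class name `RangeKindModel` and its flag are hypothetical): iterating over the Python `range(10)` is unrolled while tracing and fits without error, while iterating over `tf.range(tf.constant(10))` raises the `OperatorNotAllowedInGraphError` reported above:

```python
import numpy as np
import tensorflow as tf

class RangeKindModel(tf.keras.Model):
    def __init__(self, use_tensor_range):
        super().__init__()
        self.t = tf.Variable(0)
        self.use_tensor_range = use_tensor_range  # plain Python bool

    def train_step(self, data):
        if self.use_tensor_range:
            iterable = tf.range(tf.constant(10))  # tf.Tensor: needs AutoGraph
        else:
            iterable = range(10)  # Python range: unrolled at trace time
        for n in iterable:
            self.t.assign_add(n)
        return {"loss": self.t}

m = RangeKindModel(use_tensor_range=False)
m.compile()
m.fit(np.random.random((5)), steps_per_epoch=1)  # works

m = RangeKindModel(use_tensor_range=True)
m.compile()
m.fit(np.random.random((5)), steps_per_epoch=1)  # raises OperatorNotAllowedInGraphError
```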