tensorflow: tf.cond throws AssertionError when computing its gradient twice
System information
- Have I written custom code (as opposed to using a stock example script provided in TensorFlow): Yes
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Tested on Arch Linux and the default Google Colab env
- Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if the issue happens on mobile device: N/A
- TensorFlow installed from (source or binary): Binary
- TensorFlow version (use command below): v2.3.0-0-gb36436b087 2.3.0 (and I’ve also tried 2.4.0, same problem)
- Python version: 3.6.9
- Bazel version (if compiling from source): N/A
- GCC/Compiler version (if compiling from source): N/A
- CUDA/cuDNN version: N/A
- GPU model and memory: N/A
Describe the current behavior
Calling `g.gradient` twice on a tensor obtained from a `tf.cond` call inside a `@tf.function` causes an `AssertionError`: `assert len(func_graph.outputs) == len(grads)`. The problem appears only if the variable in `true_fn` or `false_fn` is transformed (multiplied, exponentiated, etc.).
Describe the expected behavior
TF should either:
- Successfully calculate the gradient, or
- Consistently fail on any gradient computation after the first one, even if `true_fn` doesn't contain any transformation of the variable.
Standalone code to reproduce the issue
import tensorflow as tf
# Minimal reproduction: computing the gradient of a `tf.cond` output twice
# with a persistent tape inside a `tf.function` raises
# `AssertionError: assert len(func_graph.outputs) == len(grads)` on the
# second `g.gradient` call (observed with TF 2.3.0 and 2.4.0).
a = tf.Variable(tf.ones(10))


@tf.function
def test_cond(cond_var):
    """Compute the gradient of a `tf.cond`-based loss w.r.t. `a` twice.

    Args:
        cond_var: Tensor used as the branch predicate via `cond_var > 0`
            (the repro passes `tf.convert_to_tensor(1)`).

    Returns:
        A pair `(gradient1, gradient2)`, each the result of
        `g.gradient(loss, [a])` on the same tape.
    """
    # `persistent=True` is required to call `g.gradient` more than once on
    # the same tape.
    with tf.GradientTape(persistent=True) as g:
        # Both branches compute the same value. Using `1 * a` (rather than
        # plain `a`) is what triggers the error; with plain `a` both
        # gradient calls succeed.
        loss = tf.cond(
            cond_var > 0,
            lambda: tf.reduce_sum(1 * a),
            lambda: tf.reduce_sum(1 * a),
        )
    gradient1 = g.gradient(loss, [a])
    # The second gradient computation over the same tape is the call that
    # fails inside cond_v2's gradient function.
    gradient2 = g.gradient(loss, [a])
    return gradient1, gradient2


g1, g2 = test_cond(tf.convert_to_tensor(1))
This throws the following AssertionError:
---------------------------------------------------------------------------
AssertionError Traceback (most recent call last)
<ipython-input-27-7a87d1b88655> in <module>()
15 return gradient1, gradient2
16
---> 17 g1, g2 = test_cond(tf.convert_to_tensor(1))
8 frames
/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/def_function.py in __call__(self, *args, **kwds)
778 else:
779 compiler = "nonXla"
--> 780 result = self._call(*args, **kwds)
781
782 new_tracing_count = self._get_tracing_count()
/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/def_function.py in _call(self, *args, **kwds)
821 # This is the first call of __call__, so we have to initialize.
822 initializers = []
--> 823 self._initialize(args, kwds, add_initializers_to=initializers)
824 finally:
825 # At this point we know that the initialization is complete (or less
/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/def_function.py in _initialize(self, args, kwds, add_initializers_to)
695 self._concrete_stateful_fn = (
696 self._stateful_fn._get_concrete_function_internal_garbage_collected( # pylint: disable=protected-access
--> 697 *args, **kwds))
698
699 def invalid_creator_scope(*unused_args, **unused_kwds):
/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py in _get_concrete_function_internal_garbage_collected(self, *args, **kwargs)
2853 args, kwargs = None, None
2854 with self._lock:
-> 2855 graph_function, _, _ = self._maybe_define_function(args, kwargs)
2856 return graph_function
2857
/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py in _maybe_define_function(self, args, kwargs)
3211
3212 self._function_cache.missed.add(call_context_key)
-> 3213 graph_function = self._create_graph_function(args, kwargs)
3214 self._function_cache.primary[cache_key] = graph_function
3215 return graph_function, args, kwargs
/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py in _create_graph_function(self, args, kwargs, override_flat_arg_shapes)
3073 arg_names=arg_names,
3074 override_flat_arg_shapes=override_flat_arg_shapes,
-> 3075 capture_by_value=self._capture_by_value),
3076 self._function_attributes,
3077 function_spec=self.function_spec,
/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/func_graph.py in func_graph_from_py_func(name, python_func, args, kwargs, signature, func_graph, autograph, autograph_options, add_control_dependencies, arg_names, op_return_value, collections, capture_by_value, override_flat_arg_shapes)
984 _, original_func = tf_decorator.unwrap(python_func)
985
--> 986 func_outputs = python_func(*func_args, **func_kwargs)
987
988 # invariant: `func_outputs` contains only Tensors, CompositeTensors,
/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/def_function.py in wrapped_fn(*args, **kwds)
598 # __wrapped__ allows AutoGraph to swap in a converted function. We give
599 # the function a weak reference to itself to avoid a reference cycle.
--> 600 return weak_wrapped_fn().__wrapped__(*args, **kwds)
601 weak_wrapped_fn = weakref.ref(wrapped_fn)
602
/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/func_graph.py in wrapper(*args, **kwargs)
971 except Exception as e: # pylint:disable=broad-except
972 if hasattr(e, "ag_error_metadata"):
--> 973 raise e.ag_error_metadata.to_exception(e)
974 else:
975 raise
AssertionError: in user code:
<ipython-input-27-7a87d1b88655>:14 test_cond *
gradient2 = g.gradient(loss, [a])
/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/backprop.py:1073 gradient **
unconnected_gradients=unconnected_gradients)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/imperative_grad.py:77 imperative_grad
compat.as_str(unconnected_gradients.value))
/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/backprop.py:162 _gradient_function
return grad_fn(mock_op, *out_grads)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/cond_v2.py:121 _IfGrad
true_graph, grads, util.unique_grad_fn_name(true_graph.name))
/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/cond_v2.py:384 _create_grad_func
func_graph=_CondGradFuncGraph(name, func_graph))
/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/func_graph.py:986 func_graph_from_py_func
func_outputs = python_func(*func_args, **func_kwargs)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/cond_v2.py:383 <lambda>
lambda: _grad_fn(func_graph, grads), [], {},
/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/cond_v2.py:359 _grad_fn
assert len(func_graph.outputs) == len(grads)
AssertionError:
Interestingly, if you replace `1 * a` with `a` inside the `tf.cond` call, everything works fine. Removing one of the `g.gradient` calls also makes it work.
I’ve also uploaded the code to Colab.
About this issue
- Original URL
- State: closed
- Created 4 years ago
- Comments: 15 (9 by maintainers)
We should throw a better error message for sure. I will also try to think of a fix.