tensorflow: tf.cond throws AssertionError when computing its gradient twice

System information

  • Have I written custom code (as opposed to using a stock example script provided in TensorFlow): Yes
  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Tested on Arch Linux and the default Google Colab env
  • Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if the issue happens on mobile device: N/A
  • TensorFlow installed from (source or binary): Binary
  • TensorFlow version (use command below): v2.3.0-0-gb36436b087 2.3.0 (and I’ve also tried 2.4.0, same problem)
  • Python version: 3.6.9
  • Bazel version (if compiling from source): N/A
  • GCC/Compiler version (if compiling from source): N/A
  • CUDA/cuDNN version: N/A
  • GPU model and memory: N/A

Describe the current behavior

Calling g.gradient twice on a tensor returned by tf.cond inside a @tf.function raises AssertionError: assert len(func_graph.outputs) == len(grads). The problem appears only if the variable is transformed (multiplied, exponentiated, etc.) inside true_fn or false_fn.

Describe the expected behavior

TF should either:

  • Successfully calculate the gradient, or
  • Fail consistently on every g.gradient call after the first, even when true_fn doesn’t transform the variable.

Standalone code to reproduce the issue

```python
import tensorflow as tf

a = tf.Variable(tf.ones(10))

@tf.function
def test_cond(cond_var):
    with tf.GradientTape(persistent=True) as g:
        loss = tf.cond(
            cond_var > 0,
            lambda: tf.reduce_sum(1 * a),
            lambda: tf.reduce_sum(1 * a))

    gradient1 = g.gradient(loss, [a])
    gradient2 = g.gradient(loss, [a])
    return gradient1, gradient2

g1, g2 = test_cond(tf.convert_to_tensor(1))
```
This throws the following AssertionError:
---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
<ipython-input-27-7a87d1b88655> in <module>()
     15     return gradient1, gradient2
     16 
---> 17 g1, g2 = test_cond(tf.convert_to_tensor(1))

/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/def_function.py in __call__(self, *args, **kwds)
    778       else:
    779         compiler = "nonXla"
--> 780         result = self._call(*args, **kwds)
    781 
    782       new_tracing_count = self._get_tracing_count()

/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/def_function.py in _call(self, *args, **kwds)
    821       # This is the first call of __call__, so we have to initialize.
    822       initializers = []
--> 823       self._initialize(args, kwds, add_initializers_to=initializers)
    824     finally:
    825       # At this point we know that the initialization is complete (or less

/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/def_function.py in _initialize(self, args, kwds, add_initializers_to)
    695     self._concrete_stateful_fn = (
    696         self._stateful_fn._get_concrete_function_internal_garbage_collected(  # pylint: disable=protected-access
--> 697             *args, **kwds))
    698 
    699     def invalid_creator_scope(*unused_args, **unused_kwds):

/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py in _get_concrete_function_internal_garbage_collected(self, *args, **kwargs)
   2853       args, kwargs = None, None
   2854     with self._lock:
-> 2855       graph_function, _, _ = self._maybe_define_function(args, kwargs)
   2856     return graph_function
   2857 

/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py in _maybe_define_function(self, args, kwargs)
   3211 
   3212       self._function_cache.missed.add(call_context_key)
-> 3213       graph_function = self._create_graph_function(args, kwargs)
   3214       self._function_cache.primary[cache_key] = graph_function
   3215       return graph_function, args, kwargs

/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py in _create_graph_function(self, args, kwargs, override_flat_arg_shapes)
   3073             arg_names=arg_names,
   3074             override_flat_arg_shapes=override_flat_arg_shapes,
-> 3075             capture_by_value=self._capture_by_value),
   3076         self._function_attributes,
   3077         function_spec=self.function_spec,

/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/func_graph.py in func_graph_from_py_func(name, python_func, args, kwargs, signature, func_graph, autograph, autograph_options, add_control_dependencies, arg_names, op_return_value, collections, capture_by_value, override_flat_arg_shapes)
    984         _, original_func = tf_decorator.unwrap(python_func)
    985 
--> 986       func_outputs = python_func(*func_args, **func_kwargs)
    987 
    988       # invariant: `func_outputs` contains only Tensors, CompositeTensors,

/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/def_function.py in wrapped_fn(*args, **kwds)
    598         # __wrapped__ allows AutoGraph to swap in a converted function. We give
    599         # the function a weak reference to itself to avoid a reference cycle.
--> 600         return weak_wrapped_fn().__wrapped__(*args, **kwds)
    601     weak_wrapped_fn = weakref.ref(wrapped_fn)
    602 

/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/func_graph.py in wrapper(*args, **kwargs)
    971           except Exception as e:  # pylint:disable=broad-except
    972             if hasattr(e, "ag_error_metadata"):
--> 973               raise e.ag_error_metadata.to_exception(e)
    974             else:
    975               raise

AssertionError: in user code:

    <ipython-input-27-7a87d1b88655>:14 test_cond  *
        gradient2 = g.gradient(loss, [a])
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/backprop.py:1073 gradient  **
        unconnected_gradients=unconnected_gradients)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/imperative_grad.py:77 imperative_grad
        compat.as_str(unconnected_gradients.value))
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/backprop.py:162 _gradient_function
        return grad_fn(mock_op, *out_grads)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/cond_v2.py:121 _IfGrad
        true_graph, grads, util.unique_grad_fn_name(true_graph.name))
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/cond_v2.py:384 _create_grad_func
        func_graph=_CondGradFuncGraph(name, func_graph))
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/func_graph.py:986 func_graph_from_py_func
        func_outputs = python_func(*func_args, **func_kwargs)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/cond_v2.py:383 <lambda>
        lambda: _grad_fn(func_graph, grads), [], {},
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/cond_v2.py:359 _grad_fn
        assert len(func_graph.outputs) == len(grads)

    AssertionError: 

Interestingly, if you replace 1 * a with a inside the tf.cond call, everything works fine. Removing either of the g.gradient calls also makes the error go away.
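The two working variants described above can be sketched as follows. This is a minimal sketch assuming a TensorFlow 2.x install; the function names cond_untransformed and cond_single_gradient are illustrative and not part of the original report. The first variant uses a directly in both branches and keeps both g.gradient calls on a persistent tape; the second keeps 1 * a but calls g.gradient only once and reuses the result, which also serves as a workaround when the same gradient is needed twice.

```python
import tensorflow as tf

a = tf.Variable(tf.ones(10))

@tf.function
def cond_untransformed(cond_var):
    # Variant 1 (illustrative): branches use `a` directly instead of `1 * a`,
    # and g.gradient is still called twice on the persistent tape.
    with tf.GradientTape(persistent=True) as g:
        loss = tf.cond(
            cond_var > 0,
            lambda: tf.reduce_sum(a),
            lambda: tf.reduce_sum(a))
    gradient1 = g.gradient(loss, [a])
    gradient2 = g.gradient(loss, [a])
    return gradient1, gradient2

@tf.function
def cond_single_gradient(cond_var):
    # Variant 2 (illustrative): keep `1 * a`, but compute the gradient once
    # and return the same result twice instead of calling g.gradient again.
    with tf.GradientTape() as g:
        loss = tf.cond(
            cond_var > 0,
            lambda: tf.reduce_sum(1 * a),
            lambda: tf.reduce_sum(1 * a))
    gradient = g.gradient(loss, [a])
    return gradient, gradient

g1, g2 = cond_untransformed(tf.convert_to_tensor(1))
h1, h2 = cond_single_gradient(tf.convert_to_tensor(1))
```

Reusing a single gradient result avoids a second backward pass through the tf.cond, so it does not depend on whether the branches transform the variable.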

I’ve also uploaded the code to Colab.

About this issue

  • State: closed
  • Created 4 years ago
  • Comments: 15 (9 by maintainers)

Most upvoted comments

We should throw a better error message for sure. I will also try to think of a fix.
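As a rough illustration of the better error message suggested above, the bare assert len(func_graph.outputs) == len(grads) from the traceback could be replaced by a check that explains the mismatch. This is plain Python and entirely hypothetical: check_branch_grad_count is not a TensorFlow function, and any actual fix in TensorFlow may look different.

```python
def check_branch_grad_count(branch_name, num_branch_outputs, num_grads):
    """Hypothetical stand-in for the bare
    `assert len(func_graph.outputs) == len(grads)` seen in the traceback.

    Raises a ValueError that names the branch and both counts instead of
    an empty AssertionError.
    """
    if num_branch_outputs != num_grads:
        raise ValueError(
            f"Cannot build the gradient of tf.cond branch {branch_name!r}: "
            f"the branch has {num_branch_outputs} output(s) but "
            f"{num_grads} incoming gradient(s) were supplied. One known "
            "trigger is calling GradientTape.gradient more than once on the "
            "same tf.cond result inside tf.function.")


# Matching counts pass silently; a mismatch raises a descriptive error.
check_branch_grad_count("cond_true", 1, 1)
try:
    check_branch_grad_count("cond_true", 2, 1)
except ValueError as err:
    message = str(err)
```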