tensorflow: TensorArray failure with automatic gradient computation for nested scan/map_fn/fold

I put together this simple example to show that gradient computation for a nested While loop fails in the Python interface: the gradContext for the forwardContext of the Add op's second input is None. Swapping which of the two outer_acc = tf.add(...) lines is commented out shows that the failure does not affect all control flow graphs.

import tensorflow as tf
from tensorflow.python.ops import control_flow_ops

OUTER_LOOP_MAX = 5
INNER_LOOP_MAX = 3

inner = tf.placeholder("float32", [1, INNER_LOOP_MAX])
outer = tf.placeholder("float32", [1, OUTER_LOOP_MAX])

X = tf.Variable(tf.zeros([1], dtype="float32"))

max = tf.constant(1)

def outer_cond_func(c1, outer_acc, outer_array):
    return tf.less(c1, OUTER_LOOP_MAX)


def outer_body_func(c1, outer_acc, outer_array):
    concat = tf.concat(0, [[0], tf.expand_dims(c1, 0)])
    slice = tf.slice(outer_array, concat, [1, 1])
    outer_num = tf.reduce_sum(slice)

    def inner_cond_func(c2, inner_acc, inner_array):
        return tf.less(c2, INNER_LOOP_MAX)

    def inner_body_func(c2, inner_acc, inner_array):
        concat2 = tf.concat(0, [[0], tf.expand_dims(c2, 0)])
        inner_num = tf.reduce_sum(tf.slice(inner_array, concat2, [1,1]))
        inner_acc += inner_num * outer_num * X

        c2 += 1
        return c2, inner_acc, inner_array

    _, inside_summed_products, _ = control_flow_ops.While(inner_cond_func, inner_body_func, [tf.constant(0), tf.constant(0.0), inner])

    def true_func():
        return 2*outer_num

    def false_func():
        return 3*outer_num

    cond_num = control_flow_ops.cond(tf.less(c1,max),true_func, false_func)
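    # Accumulating the inner While result is the case that fails during gradient construction.
    outer_acc = tf.add(outer_acc, inside_summed_products)
    # Swap to this line instead (cond result only) to see a control flow graph that does not fail.
    # outer_acc = tf.add(outer_acc, cond_num)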
    c1 += 1
    return c1, outer_acc, outer_array

_, value, _ = control_flow_ops.While(outer_cond_func, outer_body_func, [tf.constant(0), tf.constant(0.0), outer])

loss = value * X
train_op = tf.train.GradientDescentOptimizer(learning_rate=0.1).minimize(loss)

sess = tf.Session()
init = tf.initialize_all_variables()
sess.run(init)

for i in xrange(1, 10):
    feed_dict = {inner: [[1.0, 2.0, 3.0]], outer: [[4.0, 5.0, 6.0, 7.0, 8.0]]}
    print sess.run([train_op, loss], feed_dict=feed_dict)
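
For reference, the values the nested loops are meant to produce can be checked without TensorFlow. The following is a minimal sketch, not part of the original report: the helper names nested_value and loss_and_grad are hypothetical, and the gradient is derived by hand on the assumption that the outer accumulator adds inside_summed_products, as in the failing configuration.

```python
# Plain-Python reference for the nested-loop graph above (no TensorFlow needed).
# nested_value and loss_and_grad are illustrative names, not TensorFlow APIs.

def nested_value(inner, outer, x):
    """Mirror the two While loops: accumulate inner_num * outer_num * x."""
    outer_acc = 0.0
    for outer_num in outer:
        inner_acc = 0.0
        for inner_num in inner:
            inner_acc += inner_num * outer_num * x
        outer_acc += inner_acc
    return outer_acc


def loss_and_grad(inner, outer, x):
    """Return (loss, d(loss)/dx) for loss = value * x."""
    value = nested_value(inner, outer, x)
    loss = value * x
    # value is linear in x, so loss = k * x**2 with k = sum(inner) * sum(outer),
    # which gives d(loss)/dx = 2 * k * x.
    k = sum(inner) * sum(outer)
    grad = 2.0 * k * x
    return loss, grad


# Same feed values as the feed_dict in the report.
inner = [1.0, 2.0, 3.0]
outer = [4.0, 5.0, 6.0, 7.0, 8.0]
print(loss_and_grad(inner, outer, 0.0))
print(loss_and_grad(inner, outer, 1.0))
```

Because X is initialized to zeros in the report, both the loss and this hand-derived gradient are zero at the starting point, so a non-zero initial X would be needed to see the optimizer move once gradient construction succeeds.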

About this issue

  • State: closed
  • Created 9 years ago
  • Comments: 30 (16 by maintainers)

Most upvoted comments

Still no ETA for fixing nested scan.

ACT only requires a while_loop inside the rnncell, not a scan, and does not require TensorArray except in the outer dynamic_rnn call. You should be able to implement it without running into this bug.