tensorflow: Can't enforce shape invariants with TensorArrays in while_loop

I can’t enforce shape invariants in a while_loop if one of the inputs is a TensorArray. Here’s a minimal example:

import tensorflow as tf

def body(i,ta):
    ta = ta.write(i,1.0)
    return (i+1,ta)

arr_size = 10
ta = tf.TensorArray(tf.float32, size=arr_size)

i = tf.constant(0,tf.int32)
input = (i,ta)
cond = lambda i,_ : i < arr_size
output = tf.while_loop(cond, body,input,shape_invariants=(i.get_shape(),tf.TensorShape(arr_size)))

#works fine without shape_invariants:
#output = tf.while_loop(cond, body,input)

mat = output[1].stack()
sess = tf.InteractiveSession()
print(mat.eval())

The code above works fine if shape_invariants is not passed to while_loop. With shape_invariants, though, I get the following error:

ValueError: The shape invariant specified for TensorArray:1 is not compatible with the initial shape of the loop variable. It enters the loop with shape <unknown>, but the specified shape invariant is (10,).

Am I doing something wrong or is this a bug?

Thanks!

About this issue

  • State: closed
  • Created 7 years ago
  • Comments: 19 (18 by maintainers)

Most upvoted comments

TensorArray objects have a shape invariant value of None. This is a documentation bug.

On Jun 16, 2017 11:26 AM, “Geoffrey Irving” notifications@github.com wrote:

Assigned #7700 (https://github.com/tensorflow/tensorflow/issues/7700) to @ebrevdo (https://github.com/ebrevdo).
