tensorflow: Bug: reshape shape inference for partially defined shape

System information

  • Have I written custom code (as opposed to using a stock example script provided in TensorFlow): yes
  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Linux Ubuntu 16.04
  • TensorFlow installed from (source or binary): binary
  • TensorFlow version (use command below): v1.4.0-rc1-11-g130a514 1.4.0
  • Python version: Python 3.6.3 :: Anaconda custom (64-bit)
  • Bazel version (if compiling from source):
  • GCC/Compiler version (if compiling from source):
  • CUDA/cuDNN version:
  • GPU model and memory:
  • Exact command to reproduce: see example

Describe the problem

When the input shape for ‘tf.reshape’ is only partially defined and the new shape contains a -1 that stands in for dimensions whose sizes are known, ‘tf.reshape’ does not infer the size of that dimension. See the example code.

Source code / logs

In the second and third examples the last dimension could be inferred as 20, but ‘tf.reshape’ reports it as unknown:

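import tensorflow as tf  # TensorFlow 1.x API (tf.placeholder)

def foo(*shape):
    x = tf.placeholder(tf.float32, shape)
    # Keep the leading dimensions and collapse the last two into one.
    return tf.reshape(x, tf.concat([tf.shape(x)[:-2], [-1]], 0))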

print(foo(2, 3, 4, 5))  # Tensor("Reshape_8:0", shape=(2, 3, 20), dtype=float32)  # correct
print(foo(None, 3, 4, 5))  # Tensor("Reshape_9:0", shape=(?, 3, ?), dtype=float32)  # shape inference possible
print(foo(None, None, 4, 5))  # Tensor("Reshape_10:0", shape=(?, ?, ?), dtype=float32)  # shape inference possible
print(foo(2, 3, 4, None))  # Tensor("Reshape_11:0", shape=(2, 3, ?), dtype=float32)  # correct

Proof that shape inference is possible:

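import functools, operator

def bar(*shape):
    x = tf.placeholder(tf.float32, shape)
    tmp = x.shape[-2:]  # static shape of the last two dimensions
    if not tmp == tf.TensorShape(None):
        # Product of the two dimensions; unknown (?) if either one is unknown.
        tmp = functools.reduce(operator.mul, tmp, tf.Dimension(1))
    if str(tmp) == '?' or tmp == tf.TensorShape(None):
        shape = [-1]  # size not known statically: let reshape compute it at run time
    else:
        shape = [tmp]  # size known statically: pass it explicitly
    return tf.reshape(x, tf.concat([tf.shape(x)[:-2], shape], 0))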

print(bar(2, 3, 4, 5))  # Tensor("Reshape_21:0", shape=(2, 3, 20), dtype=float32)
print(bar(None, 3, 4, 5))  # Tensor("Reshape_22:0", shape=(?, 3, 20), dtype=float32)
print(bar(None, None, 4, 5))  # Tensor("Reshape_23:0", shape=(?, ?, 20), dtype=float32)
print(bar(2, 3, 4, None))  # Tensor("Reshape_24:0", shape=(2, 3, ?), dtype=float32)
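
The reasoning behind these results can be sketched in plain Python, without TensorFlow. This is not TensorFlow's shape inference code; `Ref` and `infer_reshape` are hypothetical names introduced only for illustration. The idea, under the assumption that an output dimension copied from `tf.shape(x)[i]` always equals input dimension `i`, is that such dimensions cancel out, so the -1 can be resolved from the remaining input dimensions whenever those are known:

```python
from functools import reduce
from operator import mul


class Ref:
    """Hypothetical marker: this output dim is a copy of input dim `index`."""

    def __init__(self, index):
        self.index = index


def infer_reshape(in_shape, target):
    """Infer the static output shape of a reshape with at most one -1.

    in_shape: list of ints, with None for unknown dimensions.
    target:   list of ints, None (unknown), -1 (wildcard) or Ref(i).
    """
    # Static value of every output dim; the wildcard stays -1 for now.
    out = [in_shape[t.index] if isinstance(t, Ref) else t for t in target]
    if -1 not in out:
        return out

    # Input dims that are copied verbatim cancel on both sides.
    used = [t.index % len(in_shape) for t in target if isinstance(t, Ref)]
    rest_in = [d for i, d in enumerate(in_shape) if i not in used]
    rest_out = [t for t in target if not isinstance(t, Ref) and t != -1]

    wildcard = None
    cancellable = len(set(used)) == len(used)  # each input dim copied once
    if cancellable and None not in rest_in and None not in rest_out:
        numer = reduce(mul, rest_in, 1)
        denom = reduce(mul, rest_out, 1)
        if denom != 0 and numer % denom == 0:
            wildcard = numer // denom
    return [wildcard if d == -1 else d for d in out]


# Same pattern as foo/bar: keep the first two dims, collapse the rest.
pattern = [Ref(0), Ref(1), -1]
print(infer_reshape([2, 3, 4, 5], pattern))
print(infer_reshape([None, 3, 4, 5], pattern))
print(infer_reshape([None, None, 4, 5], pattern))
print(infer_reshape([2, 3, 4, None], pattern))
```

In this sketch `None` plays the role of `?`. The unknown leading dimensions never enter the product, which is why the collapsed dimension is still known in the second and third cases.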

About this issue

  • State: closed
  • Created 7 years ago
  • Comments: 15 (13 by maintainers)

Most upvoted comments

Hey, sorry for not looking at this sooner, too many issues!

One note: I’m working on switching the Python API to call the C API, which means it’ll switch to using the C++ shape inference implementation. I tried running the example with the C API enabled (by setting the env variable TF_C_API_GRAPH_CONSTRUCTION=1), and unfortunately it does even worse:

Tensor("Reshape:0", shape=(2, 3, 20), dtype=float32)
Tensor("Reshape_1:0", shape=(?, ?, ?), dtype=float32)
Tensor("Reshape_2:0", shape=(?, ?, ?), dtype=float32)
Tensor("Reshape_3:0", shape=(?, ?, ?), dtype=float32)

However, whoever takes this on should still fix the C++ shape inference instead of the Python version, since the Python version will go away soon.

I’m not super familiar with shape inference, but I would guess this is doable if someone is willing to take the time to debug the example and get familiar with the shape inference code. It’s pretty tricky, but also somewhat self-contained and easy to read, so someone familiar with C++ and debugging should be able to do it given some time, even if they’re not super familiar with TF (I can offer pointers for how to get started). I’ll probably eventually do this myself if no one else does it, but it’ll take a while for me to get to it.