tensorflow: tf.shape output is wrong when net input shape is changed during import

System information

  • Have I written custom code (as opposed to using a stock example script provided in TensorFlow): yes
  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Windows 7
  • TensorFlow installed from (source or binary): pip
  • TensorFlow version (use command below): 1.8.0 (the bug is still present in 2.6.0; the code below has been updated accordingly)
  • Python version: 3.6.6
  • CUDA/cuDNN version: N/A
  • GPU model and memory: N/A
  • Bazel version: N/A
  • Mobile device: N/A
  • Exact command to reproduce: see below

Describe the problem

tf.shape returns an inconsistent result when a network is imported from a file and its input is remapped during the import. Let me create a simple net with a batch_size of 128 and save it to disk:

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

batch_size = 128
x = tf.placeholder(tf.float32, shape=(batch_size, 10), name='x')
b = tf.Variable(tf.zeros((10,)))
y = tf.add(x, b, name='y')

saver = tf.train.Saver()
with tf.Session() as sess:
  tf.global_variables_initializer().run()
  saver.save(sess, './foo')

Later, I reload this model and replace the input placeholder with a more flexible one whose batch size is undefined:

import numpy as np
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

x = tf.placeholder(tf.float32, shape=(None, 10))
restorer = tf.train.import_meta_graph('./foo.meta', input_map={'x:0': x})
y = tf.get_default_graph().get_tensor_by_name('y:0')
y_shape = tf.shape(y)
sess = tf.Session()
restorer.restore(sess, './foo')
[y_, y_shape_] = sess.run(['y:0', y_shape], {x: np.zeros((1, 10), np.float32)})
assert np.all(y_.shape == y_shape_), 'inconsistent sizes'

This results in an AssertionError: inconsistent sizes, because y_shape_ still reports the old batch size of 128, even though the output y_ itself is computed, as expected, with a batch size of 1.
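The failure can be reproduced in miniature without TensorFlow. Here is a toy sketch (assumption: `saved_graph`, `shape_op`, and the folding heuristic are illustrative stand-ins, not real TensorFlow internals) of how a Shape op can end up answering from the stale static shape recorded in the graph instead of the live tensor:

```python
import numpy as np

# Each node records the static shape saved in its `_output_shapes` attribute
# at export time (batch size was 128 when the graph was saved).
saved_graph = {
    "x": {"_output_shapes": (128, 10)},
    "y": {"_output_shapes": (128, 10)},  # inferred from x when the graph was saved
}

def shape_op(node_name, graph, runtime_value):
    """If the graph carries a fully-defined saved shape, the importer may
    reuse it as a constant instead of re-inferring from the new input."""
    saved = graph[node_name].get("_output_shapes")
    if saved is not None and all(dim is not None for dim in saved):
        return saved               # stale: taken from the .meta file
    return runtime_value.shape     # dynamic: computed from the actual tensor

y_runtime = np.zeros((1, 10), np.float32)     # actual batch size is 1
print(shape_op("y", saved_graph, y_runtime))  # -> (128, 10), not (1, 10)
```

This mirrors the issue: the recorded shape is fully defined, so it shadows the runtime shape of the remapped input.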

About this issue

  • Original URL
  • State: open
  • Created 6 years ago
  • Reactions: 3
  • Comments: 18 (10 by maintainers)

Most upvoted comments

I think the following script should remove the _output_shapes attributes from the saved model:

import sys

import tensorflow.compat.v1 as tf

metagraph_file = sys.argv[1]
metagraph = tf.MetaGraphDef()

# The MetaGraphDef is a serialized protobuf, so read and write it in binary mode.
with open(metagraph_file, "rb") as f:
  metagraph.ParseFromString(f.read())

# Drop the stale static shapes that were recorded at save time.
for node in metagraph.graph_def.node:
  if "_output_shapes" in node.attr:
    del node.attr["_output_shapes"]

with open(metagraph_file, "wb") as f:
  f.write(metagraph.SerializeToString())

In your example above, you should pass ./foo.meta as the argument (this is the serialized MetaGraphDef proto).

@feihugis

Looking at the code, @skye is going to be better versed in this than I am, but here’s a guess at what needs doing:

Right now we have a ShapeRefiner, which calls SetShape on nodes: https://github.com/tensorflow/tensorflow/blob/038675489fbf35542f6b0e1b6192876a83e8c5b6/tensorflow/core/graph/graph_constructor.cc#L623

That decodes the shape proto and sets the node's shape. It doesn't run a full pass of shape inference using shape functions; it just loads the saved shapes. This is only valid if the input_map targets have shapes that are compatible with the existing saved shapes.

So one possibility is: when we have an input_map, topologically sort the graph and propagate the input_map target shapes using shape functions, ignoring the saved shapes. To satisfy the requests in this thread, we'd propagate the new shapes even when they're incompatible with the old ones.
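That proposal can be sketched in miniature. The following toy (assumption: `shape_fns`, the tuple-based graph representation, and `propagate` are all illustrative stand-ins, not the real ShapeRefiner API) walks the graph in topological order and lets the input_map shapes override the saved ones:

```python
# Per-op shape functions: map the input shapes to an output shape.
shape_fns = {
    "Placeholder": lambda inputs, saved: saved,  # no inputs: fall back to declared shape
    "Add": lambda inputs, saved: inputs[0],      # elementwise: shape of first input
}

# Graph already in topological order: (name, op, input names, saved `_output_shapes`).
graph = [
    ("x", "Placeholder", [], (128, 10)),
    ("b", "Placeholder", [], (10,)),
    ("y", "Add", ["x", "b"], (128, 10)),
]

def propagate(graph, input_map):
    """Re-run shape inference, letting input_map shapes override saved shapes."""
    shapes = {}
    for name, op, inputs, saved in graph:
        if name in input_map:
            shapes[name] = input_map[name]  # remapped input: the new shape wins
        else:
            shapes[name] = shape_fns[op]([shapes[i] for i in inputs], saved)
    return shapes

# Remapping x to an undefined batch dimension updates y's inferred shape too.
print(propagate(graph, {"x": (None, 10)})["y"])  # -> (None, 10)
```

With an empty input_map the saved shapes are reproduced unchanged, so the pass only diverges from today's behavior when an input is actually remapped.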