tensorflow: tf.import_graph_def: graph_def is invalid at node

I’ve been trying to import a frozen graph into a new program, and do a simple forward pass, but tf.import_graph_def has been throwing a ValueError that I really can’t make sense of.

Environment info

Operating System: Ubuntu 14.04 LTS 64-bit

Installed version of CUDA and cuDNN: none

If installed from source, provide

  1. The commit hash (git rev-parse HEAD): fc9162975e52978d3af38549b570cc3cc5f0ab66
  2. The output of bazel version
Build label: 0.3.0
Build target: bazel-out/local-fastbuild/bin/src/main/java/com/google/devtools/build/lib/bazel/BazelServer_deploy.jar
Build time: Fri Jun 10 11:38:23 2016 (1465558703)
Build timestamp: 1465558703
Build timestamp as int: 1465558703

Steps to reproduce

  1. Copy the IPython Notebook from here
  2. Change sample_prediction = tf.nn.softmax(tf.nn.xw_plus_b(sample_output, w, b)) to sample_prediction = tf.nn.softmax(tf.nn.xw_plus_b(sample_output, w, b), name="sample_prediction")
  3. Modify the code like so:
with tf.Session(graph=graph) as session:
  tf.initialize_all_variables().run()
  print('Initialized')
  mean_loss = 0
  # code omitted (no changes)
  # new code below:
  saver = tf.train.Saver(tf.all_variables())
  saver.save(session, '/home/me/Documents/checkpoint.ckpt', write_meta_graph=False)
  tf.train.write_graph(graph.as_graph_def(), '/home/me/Documents', 'graph.pb')
  4. Run, and verify that checkpoint.ckpt and graph.pb have been created
  5. Run bazel build tensorflow/python/tools:freeze_graph && bazel-bin/tensorflow/python/tools/freeze_graph --input_graph=/home/me/Documents/graph.pb --input_checkpoint=/home/me/Documents/checkpoint.ckpt --output_graph=/home/me/Documents/frozen_graph.pb --output_node_names=sample_prediction
  6. Verify that frozen_graph.pb has been created
  7. Create a new IPython Notebook with the following code:
from __future__ import print_function
import os
import numpy as np
import random
import string
import tensorflow as tf
from tensorflow.python.platform import gfile
import zipfile
from six.moves import range
from six.moves.urllib.request import urlretrieve

graph = tf.Graph()
with graph.as_default():
    graph_def = tf.GraphDef()
    with open('/home/me/Documents/frozen_graph.pb', "rb") as f:
        graph_def.ParseFromString(f.read())
        sample_prediction = tf.import_graph_def(graph_def, name="", return_elements=['sample_prediction:0'])
  8. Run
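
For completeness, this is roughly the forward pass I'm ultimately trying to run once the import succeeds (just a sketch that continues the notebook cell above; the sample_input name and the vocabulary size of 27 come from the original notebook and are assumptions here):

# Forward-pass sketch, assuming the frozen graph still contains the
# notebook's 'sample_input' placeholder of shape [1, vocabulary_size].
vocabulary_size = 27  # a-z plus space, as in the original notebook

with tf.Session(graph=graph) as session:
    sample_input = graph.get_tensor_by_name('sample_input:0')
    prediction = graph.get_tensor_by_name('sample_prediction:0')
    feed = np.zeros((1, vocabulary_size), dtype=np.float32)  # dummy one-hot row
    print(session.run(prediction, feed_dict={sample_input: feed}))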

What have you tried?

  1. Originally, the graph also contained a node named saved_sample_output, and when I tried importing that frozen graph, the error complained about saved_sample_output:0 instead. I removed the name, re-wrote the checkpoint and graph files, re-froze, and re-ran the code; it then complained about Variable_17:0, which, after checking graph.pb, turned out to be the node that had originally been named saved_sample_output (a sketch for listing a graph's nodes programmatically follows this list). Beyond that, I haven't been able to find anything else out.
  2. Checked out #616 and looked at the solutions suggested for similar errors, but my import_graph_def call never had an input_map to begin with.
  3. Removing the name parameter, or the return_elements parameter, or both, hasn’t made a difference.
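
For reference, here is a quick way to list every node in the frozen graph, e.g. to see whether Assign or Variable ops survived freezing (a sketch using the standard GraphDef protobuf fields; the path matches the one above):

from __future__ import print_function
import tensorflow as tf

# node, op, name, and input are standard GraphDef/NodeDef protobuf fields.
gd = tf.GraphDef()
with open('/home/me/Documents/frozen_graph.pb', 'rb') as f:
    gd.ParseFromString(f.read())

for node in gd.node:
    print(node.op, node.name, list(node.input))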

Logs or other output that would be helpful

ValueError                                Traceback (most recent call last)
<ipython-input-46-3423c2073e62> in <module>()
     53     with open('/home/me/Documents/frozen_graph.pb', "rb") as f:
     54         graph_def.ParseFromString(f.read())
---> 55         sample_prediction = tf.import_graph_def(graph_def, name="", return_elements=['sample_prediction:0'])

/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/importer.pyc in import_graph_def(graph_def, input_map, return_elements, name, op_dict)
    318           except TypeError as te:
    319             raise ValueError(_InvalidNodeMessage(
--> 320                 node, 'Input tensor %r %s' % (input_name, te)))
    321 
    322       # pylint: disable=protected_access

ValueError: graph_def is invalid at node u'Assign_4': Input tensor 'Variable_17:0' Cannot convert a tensor of type float32 to an input of type float32_ref.

About this issue

  • State: closed
  • Created 8 years ago
  • Reactions: 5
  • Comments: 16 (6 by maintainers)

Most upvoted comments

It’s possibly because you trained and froze the model with an older version of TensorFlow and are now trying to import the graph with a newer one. Try to keep the TensorFlow versions consistent. Please refer to this issue.
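
If it helps, a quick way to compare the two is to print the installed TensorFlow version next to the version information recorded in the frozen graph (a sketch; graph_def.versions is a standard GraphDef field, and the exact numbers depend on the builds involved):

from __future__ import print_function
import tensorflow as tf

gd = tf.GraphDef()
with open('/home/me/Documents/frozen_graph.pb', 'rb') as f:
    gd.ParseFromString(f.read())

# The producer version is stamped into the GraphDef by the TensorFlow build
# that wrote it; an importer whose GraphDef version is below min_consumer
# will reject the graph.
print('installed TensorFlow:', tf.__version__)
print('graph producer version:', gd.versions.producer)
print('graph min consumer version:', gd.versions.min_consumer)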