tensorflow: tf.import_graph_def: graph_def is invalid at node
I’ve been trying to import a frozen graph into a new program, and do a simple forward pass, but tf.import_graph_def
has been throwing a ValueError that I really can’t make sense of.
Environment info
Operating System: Ubuntu 14.04 LTS 64-bit
Installed version of CUDA and cuDNN: none
If installed from source, provide
- The commit hash (git rev-parse HEAD): fc9162975e52978d3af38549b570cc3cc5f0ab66
- The output of bazel version:
Build label: 0.3.0
Build target: bazel-out/local-fastbuild/bin/src/main/java/com/google/devtools/build/lib/bazel/BazelServer_deploy.jar
Build time: Fri Jun 10 11:38:23 2016 (1465558703)
Build timestamp: 1465558703
Build timestamp as int: 1465558703
Steps to reproduce
- Copy the IPython Notebook from here
- Change
sample_prediction = tf.nn.softmax(tf.nn.xw_plus_b(sample_output, w, b))
to
sample_prediction = tf.nn.softmax(tf.nn.xw_plus_b(sample_output, w, b), name="sample_prediction")
- Modify the code like so:
with tf.Session(graph=graph) as session:
    tf.initialize_all_variables().run()
    print('Initialized')
    mean_loss = 0
    # code omitted (no changes)
    # new code below:
    saver = tf.train.Saver(tf.all_variables())
    saver.save(session, '/home/me/Documents/checkpoint.ckpt', write_meta_graph=False)
    tf.train.write_graph(graph.as_graph_def(), '/home/me/Documents', 'graph.pb')
- Run, and verify that checkpoint.ckpt and graph.pb have been created
- Run
bazel build tensorflow/python/tools:freeze_graph && bazel-bin/tensorflow/python/tools/freeze_graph --input_graph=/home/me/Documents/graph.pb --input_checkpoint=/home/me/Documents/checkpoint.ckpt --output_graph=/home/me/Documents/frozen_graph.pb --output_node_names=sample_prediction
- Verify that frozen_graph.pb has been created
- Create a new IPython Notebook with the following code:
from __future__ import print_function
import os
import numpy as np
import random
import string
import tensorflow as tf
from tensorflow.python.platform import gfile
import zipfile
from six.moves import range
from six.moves.urllib.request import urlretrieve
graph = tf.Graph()
with graph.as_default():
    graph_def = tf.GraphDef()
    with open('/home/me/Documents/frozen_graph.pb', "rb") as f:
        graph_def.ParseFromString(f.read())
    sample_prediction = tf.import_graph_def(graph_def, name="", return_elements=['sample_prediction:0'])
- Run
What have you tried?
- Originally, the graph also contained a node named saved_sample_output, and when I tried importing that frozen graph, the error complained about saved_sample_output:0. I tried removing the name, re-writing the checkpoint and graph files, re-freezing, and re-running the code. It then complained about Variable_17:0, which, after checking graph.pb, was what had originally been named saved_sample_output. Other than that, I haven’t been able to find anything else out.
- Checked out #616 and looked at the solutions suggested for similar errors, but my import_graph_def never had an input map to begin with.
- Removing the name parameter, or the return_elements parameter, or both, hasn’t made a difference.
Logs or other output that would be helpful
ValueError Traceback (most recent call last)
<ipython-input-46-3423c2073e62> in <module>()
53 with open('/home/me/Documents/frozen_graph.pb', "rb") as f:
54 graph_def.ParseFromString(f.read())
---> 55 sample_prediction = tf.import_graph_def(graph_def, name="", return_elements=['sample_prediction:0'])
/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/importer.pyc in import_graph_def(graph_def, input_map, return_elements, name, op_dict)
318 except TypeError as te:
319 raise ValueError(_InvalidNodeMessage(
--> 320 node, 'Input tensor %r %s' % (input_name, te)))
321
322 # pylint: disable=protected_access
ValueError: graph_def is invalid at node u'Assign_4': Input tensor 'Variable_17:0' Cannot convert a tensor of type float32 to an input of type float32_ref.
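One possible reading of this error, not confirmed anywhere in this thread, is that the frozen graph still contains an Assign node whose first input used to be a variable but is now a constant, so it no longer has the float32_ref type that Assign expects. The sketch below is a small diagnostic for that hypothesis. The helper names (producer_name, find_broken_assigns) and the two op sets are made up for illustration, and the op sets are not exhaustive. It only relies on each node having name, op and input attributes, as the entries of graph_def.node do, so the example data uses plain stand-in objects rather than TensorFlow.

```python
from types import SimpleNamespace

# Ops that need a mutable (ref-typed) first input. Illustrative, not exhaustive.
REF_CONSUMING_OPS = {"Assign", "AssignAdd", "AssignSub"}
# Ops assumed to produce a ref-typed output in graphs of this era.
VARIABLE_OPS = {"Variable", "VariableV2"}


def producer_name(input_name):
    """Strip the control-dependency marker and output index from an input name."""
    return input_name.lstrip("^").split(":")[0]


def find_broken_assigns(nodes):
    """Return (node name, input name, producer op) for ref consumers fed by non-variables.

    `nodes` is any iterable of objects with `.name`, `.op` and `.input`,
    for example `graph_def.node` from a parsed GraphDef.
    """
    nodes = list(nodes)  # allow a second pass even if a generator was passed in
    op_by_name = {node.name: node.op for node in nodes}
    broken = []
    for node in nodes:
        if node.op not in REF_CONSUMING_OPS or not node.input:
            continue
        target = node.input[0]  # the first input is the tensor being assigned to
        producer_op = op_by_name.get(producer_name(target), "<missing>")
        if producer_op not in VARIABLE_OPS:
            broken.append((node.name, target, producer_op))
    return broken


# Hypothetical stand-in for graph_def.node so the sketch runs without TensorFlow.
example_nodes = [
    SimpleNamespace(name="Variable_17", op="Const", input=[]),
    SimpleNamespace(name="sample_output", op="Mul", input=[]),
    SimpleNamespace(name="Assign_4", op="Assign", input=["Variable_17", "sample_output"]),
]
print(find_broken_assigns(example_nodes))
```

Against the real file, the same function could be called as find_broken_assigns(graph_def.node) right after ParseFromString. A non-empty result would point at the nodes to look at first; an empty result would suggest the cause lies elsewhere.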
About this issue
- State: closed
- Created 8 years ago
- Reactions: 5
- Comments: 16 (6 by maintainers)
This is possibly because you trained and froze the model with an older version of TensorFlow and are now importing the graph with a newer one. Try to keep the TensorFlow version consistent across training, freezing, and importing. Please refer to this issue.