tensorflow: Freezing network with batch norm does not work with TRT
System information
- Have I written custom code (as opposed to using a stock example script provided in TensorFlow): Yes
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04): 16.04
- TensorFlow installed from (source or binary): Source
- TensorFlow version (use command below): 1.10.1
- Python version: 3.5
- CUDA/cuDNN version: 9.0
- Bazel version: N/A
- GPU model and memory: N/A
- Mobile device: N/A
- Exact command to reproduce:
```python
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import numpy as np
import tensorflow.contrib.tensorrt as trt
import tensorflow as tf
from tensorflow.python.platform import gfile
from tensorflow.keras import backend as K

path = "/tmp"
output_trt_pb = os.path.join(path, "output_trt.pb")

np.random.seed(0)
X, Y = np.random.rand(1000, 100, 100, 3), np.random.rand(1000, 100, 100, 16)

with K.get_session() as sess:
    inp = tf.keras.layers.Input(shape=(100, 100, 3), name="input")
    x = tf.keras.layers.Conv2D(16, (3, 3), padding="same",
                               kernel_initializer="ones", name="conv2d",
                               use_bias=False)(inp)
    x = tf.keras.layers.BatchNormalization(name="bn", fused=True)(x)
    x = tf.keras.layers.Activation("relu")(x)
    model = tf.keras.models.Model(inp, x)
    model.compile("adam", "mse")
    model.fit(X, Y, epochs=3, verbose=True)

    # Applying the node fixes from https://github.com/tensorflow/tensorflow/issues/3628
    # here doesn't help.
    graph_def = sess.graph_def
    output_graph_def = tf.graph_util.convert_variables_to_constants(
        sess,                                  # the session is used to retrieve the weights
        graph_def,                             # the graph_def is used to retrieve the nodes
        [i.name[:-2] for i in model.outputs])  # the output node names select the useful nodes
    trt_graph = trt.create_inference_graph(
        output_graph_def,
        [i.name[:-2] for i in model.outputs],
        max_batch_size=1,
        max_workspace_size_bytes=1256 << 20,
        precision_mode="FP32")
```
Describe the problem
When I freeze a protobuf that contains batch normalization and then try to use it with TRT, it fails with the following error:
```
InvalidArgumentError: Input 0 of node bn/cond/ReadVariableOp/Switch_1 was passed float from bn/gamma_1:0 incompatible with expected resource.
```
This seems to be the same issue reported in other threads, such as https://github.com/tensorflow/tensorflow/issues/3628, and I tried to include the suggested fixes in my code (see below), but they do not help.
I tried both fused=True and fused=False, as well as trainable=True and trainable=False, in BatchNormalization.
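For reference, the node-rewriting fix suggested in that thread looks roughly like the following (a sketch of what I applied to graph_def before freezing; it rewrites ref-type variable ops, and it did not help in my case):
```python
# Node rewrites suggested in https://github.com/tensorflow/tensorflow/issues/3628
# (a sketch); applied to graph_def before convert_variables_to_constants.
for node in graph_def.node:
    if node.op == 'RefSwitch':
        node.op = 'Switch'
        for index in range(len(node.input)):
            # Route the moving-average inputs through the variable's read op.
            if 'moving_' in node.input[index]:
                node.input[index] = node.input[index] + '/read'
    elif node.op == 'AssignSub':
        node.op = 'Sub'
        if 'use_locking' in node.attr:
            del node.attr['use_locking']
```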
If you comment out
```python
x = tf.keras.layers.BatchNormalization(name="bn", fused=True)(x)
```
then everything works fine.
Any comments would be appreciated.
About this issue
- Original URL
- State: closed
- Created 6 years ago
- Reactions: 5
- Comments: 16 (6 by maintainers)
I solved a similar issue. In my case, I added `tf.keras.backend.set_learning_phase(0)`. Don't forget to clear the session before setting the learning phase to 0 (which puts Keras in inference mode, so the batch norm won't be updated)!
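In code, the suggested ordering would look roughly like this (a minimal sketch):
```python
import tensorflow as tf

tf.keras.backend.clear_session()        # clear any existing graph/session first
tf.keras.backend.set_learning_phase(0)  # then put Keras in inference mode
# build or load the model after this point, then freeze it
```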
@samikama This is not a solution as
tf.keras.backend.set_learning_phase(0)puts the batch norm layers in inference mode and thus the will not update during training @bhavana3 @ardianumam @fferroni A workaround is to save the weights of the trained model, clear the session, recreate the model this time withtf.keras.backend.set_learning_phase(0)and then load the weights back in before freezing. https://github.com/tensorflow/tensorflow/issues/31331#issuecomment-518655879