tensorflow: Freezing network with batch norm does not work with TRT

System information

  • Have I written custom code (as opposed to using a stock example script provided in TensorFlow): Yes
  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): 16.04
  • TensorFlow installed from (source or binary): Source
  • TensorFlow version (use command below): 1.10.1
  • Python version: 3.5
  • CUDA/cuDNN version: 9.0
  • Bazel version: N/A
  • GPU model and memory: N/A
  • Mobile device: N/A
  • Exact command to reproduce:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import numpy as np
import tensorflow.contrib.tensorrt as trt
import tensorflow as tf
from tensorflow.python.platform import gfile
from tensorflow.keras import backend as K

path = "/tmp"
output_trt_pb = os.path.join(path, "output_trt.pb")

np.random.seed(0)

X, Y = np.random.rand(1000, 100, 100, 3), np.random.rand(1000, 100, 100, 16)

with K.get_session() as sess:
    
    inp = tf.keras.layers.Input(shape=(100,100,3), name="input")
    x = tf.keras.layers.Conv2D(16, (3,3), padding="same", kernel_initializer="ones", name="conv2d", use_bias=False)(inp)
    x = tf.keras.layers.BatchNormalization(name="bn", fused=True)(x)
    x = tf.keras.layers.Activation("relu")(x)
    model = tf.keras.models.Model(inp, x)
    model.compile("adam", "mse")
    model.fit(X, Y, epochs=3, verbose=True)
    
    # Applying the node fix from https://github.com/tensorflow/tensorflow/issues/3628 here does not help
    graph_def = sess.graph_def
    
    output_graph_def = tf.graph_util.convert_variables_to_constants(
        sess,                                     # The session is used to retrieve the weights
        graph_def,                                # The graph_def is used to retrieve the nodes
        [i.name[:-2] for i in model.outputs]      # The output node names are used to select the useful nodes
    )
    
    trt_graph = trt.create_inference_graph(output_graph_def,
                                           [i.name[:-2] for i in model.outputs],
                                           max_batch_size=1,
                                           max_workspace_size_bytes=1256 << 20,
                                           precision_mode="FP32")
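The comment in the repro about fixing nodes refers to a GraphDef rewrite posted in issue #3628. Below is a sketch of that kind of rewrite, assuming the commonly quoted form (RefSwitch becomes Switch, AssignSub/AssignAdd become Sub/Add); the function name `fix_ref_nodes` is hypothetical and this is an approximation of that thread's snippet, not a verbatim copy. It works on any object with a `node` list whose entries have `op`, `input`, and `attr` fields, so it accepts a real `GraphDef` as well as simple stand-ins. Note that it only rewrites ops created for reference variables: a plain `Switch` whose input is a resource, such as `bn/cond/ReadVariableOp/Switch_1` in the error reported here, passes through unchanged, which would be consistent with the observation that this fix does not help.

```python
def fix_ref_nodes(graph_def):
    """Rewrite reference-variable ops in place (sketch of the issue #3628 fix).

    graph_def: a GraphDef, or any object with a `node` list whose items
    have `op` (str), `input` (list of str), and `attr` (mapping) fields.
    """
    for node in graph_def.node:
        if node.op == "RefSwitch":
            # Turn the ref-typed switch into a value-typed one and point
            # moving-statistics inputs at their '/read' identity nodes.
            node.op = "Switch"
            for i, name in enumerate(node.input):
                if "moving_" in name:
                    node.input[i] = name + "/read"
        elif node.op == "AssignSub":
            # Replace the in-place update with a plain subtraction.
            node.op = "Sub"
            if "use_locking" in node.attr:
                del node.attr["use_locking"]
        elif node.op == "AssignAdd":
            # Replace the in-place update with a plain addition.
            node.op = "Add"
            if "use_locking" in node.attr:
                del node.attr["use_locking"]
        # Any other op, including a resource-typed Switch, is left untouched.
    return graph_def
```

In the repro it would be called as `fix_ref_nodes(graph_def)` before `convert_variables_to_constants`.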

Describe the problem

When I freeze a graph that contains batch normalization and then try to convert it with TRT, it fails with the following error:

InvalidArgumentError: Input 0 of node bn/cond/ReadVariableOp/Switch_1 was passed float from bn/gamma_1:0 incompatible with expected resource.

This seems to be a known problem discussed in other threads, such as https://github.com/tensorflow/tensorflow/issues/3628. I tried to include the suggested fixes in my code, but they do not help. I tried both fused=True and fused=False, and also both trainable=False and trainable=True in BatchNormalization. If I comment out x = tf.keras.layers.BatchNormalization(name="bn", fused=True)(x), everything works fine.

Any comment would be appreciated.

About this issue

  • State: closed
  • Created 6 years ago
  • Reactions: 5
  • Comments: 16 (6 by maintainers)

Most upvoted comments

I solved a similar issue; in my case, I added tf.keras.backend.set_learning_phase(0).

Don’t forget to clear the session before setting the learning phase to 0 (this puts Keras in inference mode, so the batch norm layers won’t be updated)!

from tensorflow.keras import backend as K

K.clear_session()
K.set_learning_phase(0)

with K.get_session() as sess:
    ...

@samikama This is not a solution, because tf.keras.backend.set_learning_phase(0) puts the batch norm layers in inference mode, so they will not update during training. @bhavana3 @ardianumam @fferroni A workaround is to save the weights of the trained model, clear the session, recreate the model (this time after calling tf.keras.backend.set_learning_phase(0)), and then load the weights back in before freezing. See https://github.com/tensorflow/tensorflow/issues/31331#issuecomment-518655879
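The workaround can be sketched as a single function, assuming TensorFlow 1.x (`tf.keras.backend.get_session` and `tf.graph_util.convert_variables_to_constants`). `build_model` and `rebuild_and_freeze` are hypothetical names: `build_model` stands for whatever code recreates the trained architecture, and the weights file is assumed to have been written earlier with `model.save_weights(...)` on the trained model. The `backend` and `graph_util` parameters default to the real TensorFlow modules; they are injectable only so the order of calls can be checked without TensorFlow installed.

```python
def rebuild_and_freeze(build_model, weights_path, backend=None, graph_util=None):
    """Rebuild the model with learning phase 0, reload weights, and freeze it.

    build_model: hypothetical zero-argument function that recreates the same
        architecture that was trained (e.g. conv2d -> bn -> relu).
    weights_path: file written earlier by model.save_weights(weights_path).
    backend, graph_util: default to tf.keras.backend and tf.graph_util.
    Returns (frozen GraphDef, list of output node names).
    """
    if backend is None:
        from tensorflow.keras import backend  # TF 1.x API assumed
    if graph_util is None:
        from tensorflow import graph_util     # TF 1.x API assumed
    backend.clear_session()           # discard the training graph and session
    backend.set_learning_phase(0)     # must be set before any layer is built
    model = build_model()             # BN layers should now use the inference path
    model.load_weights(weights_path)  # restore trained gamma/beta/moving statistics
    sess = backend.get_session()
    output_names = [t.op.name for t in model.outputs]
    frozen = graph_util.convert_variables_to_constants(
        sess, sess.graph_def, output_names)
    return frozen, output_names
```

The returned graph and names can then be passed to `trt.create_inference_graph` in the same way as in the original repro. Training itself still runs with the default learning phase, so batch norm keeps updating its statistics; only the rebuilt copy used for freezing is in inference mode.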