tensorflow: Error freezing a saved model that contains a tf.keras.layers.BatchNormalization layer

System information

  • Have I written custom code: Yes
  • OS Platform and Distribution: Windows 10, version 1903
  • TensorFlow installed from: pip
  • TensorFlow version: 1.14.0
  • Python version: 3.7.3
  • GPU model and memory: RTX 2080 Ti

Problem

To make a frozen graph, I first create a saved model using tf.saved_model.simple_save and then freeze it using tensorflow.python.tools.freeze_graph.freeze_graph.

If a model contains any tf.keras.layers.BatchNormalization layers, freezing fails in TF 1.14.0 with:

ValueError: Tensor name 'batch_normalization/cond/ReadVariableOp/Switch:1' is invalid.

TF 1.13.1 does not give this error.
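The name in the error message is a tensor name, which TensorFlow writes as `<op name>:<output index>`. Here it points at output 1 of a `Switch` op inside the layer's `cond` block (a `Switch` op has two outputs: 0 for the false branch and 1 for the true branch). The helper below is a minimal, hypothetical sketch of that naming convention in plain Python; `split_tensor_name` is an illustrative name and not TensorFlow's own parser.

```python
def split_tensor_name(tensor_name):
    """Split 'op_name:index' into (op_name, index).

    Hypothetical helper illustrating TensorFlow's tensor-naming convention.
    A name without ':' refers to output 0 of the op. A leading '^' (used for
    control inputs in a GraphDef) is stripped.
    """
    name = tensor_name.lstrip("^")
    op_name, sep, index = name.rpartition(":")
    if not sep:
        # No ':' present: the whole string is the op name, output 0.
        return name, 0
    if not index.isdigit():
        raise ValueError("Invalid tensor name: %r" % tensor_name)
    return op_name, int(index)


op, idx = split_tensor_name("batch_normalization/cond/ReadVariableOp/Switch:1")
print(op, idx)  # prints: batch_normalization/cond/ReadVariableOp/Switch 1
```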

Code to reproduce the issue

import tensorflow as tf
import tensorflow.keras.backend as K
import os
import datetime
from tensorflow.python.tools import freeze_graph
from tensorflow.keras.layers import Dense, Flatten, Conv2D, Input, Activation, BatchNormalization
from tensorflow.keras.models import Model

inputs = Input(shape=(128, 128, 1))
x = Conv2D(4, (3, 3))(inputs)
x = BatchNormalization()(x)
x = Activation('relu')(x)
x = Flatten()(x)
x = Dense(5, activation='softmax')(x)
model = Model(inputs, x, name='test')
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

K.set_learning_phase(0)
save_dir = "./tmp_{:%Y-%m-%d_%H%M%S}".format(datetime.datetime.now())
tf.saved_model.simple_save(K.get_session(),
                           save_dir,
                           inputs={"input": model.inputs[0]},
                           outputs={"output": model.outputs[0]})

freeze_graph.freeze_graph(None,
                          None,
                          None,
                          None,
                          model.outputs[0].op.name,
                          None,
                          None,
                          os.path.join(save_dir, "frozen_model.pb"),
                          False,
                          "",
                          input_saved_model_dir=save_dir)

Update:

This seems to be a problem in graph_util_impl.py, in particular the changes in https://github.com/tensorflow/tensorflow/commit/0f486fc67070ba888204741c404a55a5f1a41fbc#diff-2d2827fd48cee6884e3587c901ad6952

@gargn

If I revert this file to its 1.13 version, the error goes away.

Update 2:

If I call K.set_learning_phase(0) before creating the model, freezing works. However, I don't know what effect this has. Does it just turn off batch normalisation altogether?

Update 3:

Final remarks:

  • Calling K.set_learning_phase(0) before creating the model lets it save; however, the batch normalisation layer then doesn't seem to do anything (are its updates turned off?), so this is not a solution.
  • Reverting graph_util_impl.py to its 1.13.1 version lets it save without error; however, loading the frozen graph from the protobuf then fails with ValueError: Input 0 of node batch_normalization/cond/ReadVariableOp/Switch was passed float from batch_normalization/gamma:0 incompatible with expected resource.
  • The workaround is to save the model weights, clear the session (so that the tensor names are not changed by having two copies of the model), set the learning phase to 0, recreate the model, load the weights, and then freeze (example code in my comment below).
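These notes suggest the learning phase matters at the moment the graph is built, not when it is frozen. The following is a toy model in plain Python (no TensorFlow) of that behaviour, assuming the layer picks its training or inference branch with something like Keras's internal `smart_cond`: if the learning phase is already a Python constant, only one branch is built; if it is still a placeholder, a conditional with `Switch` ops ends up in the graph. `ToyGraph`, `smart_cond`, `build_batch_norm` and `PLACEHOLDER` are all hypothetical names, and the op names are made up.

```python
PLACEHOLDER = object()  # stands in for the (not yet fixed) learning-phase tensor


class ToyGraph:
    """Records op names, the way a graph accumulates ops as layers are built."""

    def __init__(self):
        self.ops = []

    def add_op(self, name):
        self.ops.append(name)
        return name


def smart_cond(graph, pred, true_fn, false_fn, prefix):
    """Build only one branch if pred is a Python constant, else a conditional."""
    if pred is not PLACEHOLDER:
        return true_fn() if pred else false_fn()
    # pred is only known at run time: Switch, both branches and Merge are built.
    graph.add_op(prefix + "/cond/Switch")
    true_fn()
    false_fn()
    return graph.add_op(prefix + "/cond/Merge")


def build_batch_norm(graph, learning_phase, prefix="batch_normalization"):
    return smart_cond(
        graph,
        learning_phase,
        lambda: graph.add_op(prefix + "/batch_statistics"),   # training branch (phase 1)
        lambda: graph.add_op(prefix + "/moving_statistics"),  # inference branch (phase 0)
        prefix,
    )


# Learning phase set only after the model is built (the failing case).
g1 = ToyGraph()
build_batch_norm(g1, PLACEHOLDER)

# Learning phase set to 0 before the model is built (the workaround graph).
g2 = ToyGraph()
build_batch_norm(g2, 0)

print(g1.ops)
print(g2.ops)
```

In this toy, training a model built with the phase fixed to 0 would only ever use the inference branch, which is consistent with the observation that this alone is not a fix; the workaround instead trains in a graph that still has the conditional and freezes a fresh graph built with the phase already fixed.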

About this issue

  • State: closed
  • Created 5 years ago
  • Comments: 17 (2 by maintainers)

Most upvoted comments

@wenshuangwang I finally found a workaround:

  1. Create the model using a function (do not call K.set_learning_phase(0) at this stage):

       def create_model():
           inputs = Input(...)
           ...
           return model

       model = create_model()

  2. Train the model.
  3. Save the weights: model.save_weights("weights.h5")
  4. Clear the session and set the learning phase to 0:

       K.clear_session()
       K.set_learning_phase(0)

  5. Recreate the model and load the weights:

       model = create_model()
       model.load_weights("weights.h5")

  6. Freeze as before.

I have confirmed that this creates the frozen model protobuf without errors. Loading the model (using tf.import_graph_def) also works without errors, and it is running in my production code (using TensorFlow Java 1.14).

Here is a full working example:

import tensorflow as tf
import tensorflow.keras.backend as K
import os
import datetime
from tensorflow.python.tools import freeze_graph
from tensorflow.keras.layers import Dense, Flatten, Conv2D, Input, Activation, BatchNormalization, Lambda
from tensorflow.keras.models import Model


def create_model():
    inputs = Input(shape=(128, 128, 1))
    x = Conv2D(4, (3, 3))(inputs)
    x = BatchNormalization()(x)
    # x = Lambda((lambda x: tf.layers.batch_normalization(x)))(x)
    x = Activation('relu')(x)
    x = Flatten()(x)
    x = Dense(5, activation='softmax')(x)
    model = Model(inputs, x, name='test')
    model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
    return model


model = create_model()

# Training goes here...

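# Save the trained weights so they can be restored into a fresh graph.
model.save_weights("weights.h5")

# Start a new graph so layer (and tensor) names are not uniquified with _1
# suffixes, and fix the learning phase to 0 (inference) before any layer is
# built, so BatchNormalization is built for inference only.
K.clear_session()
K.set_learning_phase(0)

# Rebuild the same architecture in the new graph and restore the weights.
model = create_model()
model.load_weights("weights.h5")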

save_dir = "./tmp_{:%Y-%m-%d_%H%M%S}".format(datetime.datetime.now())
tf.saved_model.simple_save(K.get_session(),
                           save_dir,
                           inputs={"input": model.inputs[0]},
                           outputs={"output": model.outputs[0]})

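freeze_graph.freeze_graph(None,                      # input_graph
                          None,                      # input_saver
                          None,                      # input_binary
                          None,                      # input_checkpoint
                          model.outputs[0].op.name,  # output_node_names
                          None,                      # restore_op_name
                          None,                      # filename_tensor_name
                          os.path.join(save_dir, "frozen_model.pb"),  # output_graph
                          False,                     # clear_devices
                          "",                        # initializer_nodes
                          input_saved_model_dir=save_dir)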
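Before freezing, it may be worth checking that the rebuilt inference graph no longer contains the learning-phase conditional that triggered the original error. A minimal sketch, assuming you pass in the graph's node names (in TF 1.x they can be collected with `[n.name for n in K.get_session().graph.as_graph_def().node]`). `find_cond_switch_nodes` is a hypothetical helper, and matching on a `cond` scope plus a `Switch` op name is a heuristic based on the node name in the error message, so it may also flag unrelated conditionals elsewhere in the graph.

```python
import re

# Matches a 'cond' or 'cond_N' name scope anywhere in a node name.
_COND_SCOPE = re.compile(r"(^|/)cond(_\d+)?/")


def find_cond_switch_nodes(node_names, layer_prefix=""):
    """Return node names that look like Switch ops inside a cond scope.

    Only names starting with layer_prefix are considered. Heuristic only.
    """
    hits = []
    for name in node_names:
        if not name.startswith(layer_prefix):
            continue
        last_component = name.rsplit("/", 1)[-1]
        # Switch ops may be uniquified as Switch_1, Switch_2, ...
        if _COND_SCOPE.search(name) and last_component.startswith("Switch"):
            hits.append(name)
    return hits


# Made-up node names, for illustration only.
before = [
    "input_1",
    "conv2d/Conv2D",
    "batch_normalization/cond/Switch",
    "batch_normalization/cond/ReadVariableOp/Switch",
    "batch_normalization/cond/Merge",
    "dense/Softmax",
]
after = ["input_1", "conv2d/Conv2D", "batch_normalization/batchnorm/add_1", "dense/Softmax"]

print(find_cond_switch_nodes(before, "batch_normalization"))
print(find_cond_switch_nodes(after, "batch_normalization"))  # prints: []
```

A non-empty result for the batch_normalization scope suggests the layer was built before the learning phase was fixed.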