tensorflow: Error freezing saved model if it contains a tf.keras.layers.BatchNormalization layer
System information
- Have I written custom code: Yes
- OS Platform and Distribution: Windows 1903
- TensorFlow installed from: pip
- TensorFlow version: 1.14.0
- Python version: 3.7.3
- GPU model and memory: RTX 2080 Ti
Problem
To make a frozen graph, I first create a saved model using tf.saved_model.simple_save and then freeze it using tensorflow.python.tools.freeze_graph.freeze_graph.
If a model contains any tf.keras.layers.BatchNormalization layers, freezing fails in TF 1.14.0 with:
ValueError: Tensor name 'batch_normalization/cond/ReadVariableOp/Switch:1' is invalid.
TF 1.13.1 does not give this error.
Code to reproduce the issue
import tensorflow as tf
import tensorflow.keras.backend as K
import os
import datetime
from tensorflow.python.tools import freeze_graph
from tensorflow.keras.layers import Dense, Flatten, Conv2D, Input, Activation, BatchNormalization
from tensorflow.keras.models import Model
inputs = Input(shape=(128, 128, 1))
x = Conv2D(4, (3, 3))(inputs)
x = BatchNormalization()(x)
x = Activation('relu')(x)
x = Flatten()(x)
x = Dense(5, activation='softmax')(x)
model = Model(inputs, x, name='test')
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
K.set_learning_phase(0)
save_dir = "./tmp_{:%Y-%m-%d_%H%M%S}".format(datetime.datetime.now())
tf.saved_model.simple_save(K.get_session(),
                           save_dir,
                           inputs={"input": model.inputs[0]},
                           outputs={"output": model.outputs[0]})
freeze_graph.freeze_graph(None,
                          None,
                          None,
                          None,
                          model.outputs[0].op.name,
                          None,
                          None,
                          os.path.join(save_dir, "frozen_model.pb"),
                          False,
                          "",
                          input_saved_model_dir=save_dir)
Update:
Seems to be a problem in graph_util_impl.py, in particular https://github.com/tensorflow/tensorflow/commit/0f486fc67070ba888204741c404a55a5f1a41fbc#diff-2d2827fd48cee6884e3587c901ad6952
If I change this file back to its 1.13 version there is no more error.
Update 2:
If I put K.set_learning_phase(0) before creating the model, it works. I don’t know what effect this has, though. Does it just turn off batch normalisation altogether?
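For reference, here is a minimal pure-Python sketch of what the learning phase changes for a batch normalisation layer. It is not TensorFlow code: the function `batch_norm` and its arguments are made up for illustration, and only the defaults (momentum 0.99, epsilon 1e-3, moving mean starting at 0 and moving variance at 1) follow the Keras layer. With learning phase 0 the layer is still applied, but it normalises with the stored moving statistics and never updates them.

```python
import math


def batch_norm(batch, moving_mean, moving_var, training,
               gamma=1.0, beta=0.0, momentum=0.99, eps=1e-3):
    """Normalise one feature over a batch.

    Returns (outputs, moving_mean, moving_var). Hypothetical helper,
    written only to illustrate training vs. inference behaviour.
    """
    if training:
        # Learning phase 1: use the statistics of the current batch ...
        mean = sum(batch) / len(batch)
        var = sum((v - mean) ** 2 for v in batch) / len(batch)
        # ... and update the moving averages used later at inference time.
        moving_mean = momentum * moving_mean + (1 - momentum) * mean
        moving_var = momentum * moving_var + (1 - momentum) * var
    else:
        # Learning phase 0: use the stored moving statistics, no update.
        mean, var = moving_mean, moving_var
    out = [gamma * (v - mean) / math.sqrt(var + eps) + beta for v in batch]
    return out, moving_mean, moving_var


batch = [10.0, 12.0, 14.0]

# Training mode: outputs are centred on the batch, moving stats move.
train_out, train_mean, train_var = batch_norm(batch, 0.0, 1.0, training=True)

# Inference mode with freshly initialised stats (mean 0, variance 1):
# outputs are almost the raw inputs and the stats stay at their initial values.
infer_out, infer_mean, infer_var = batch_norm(batch, 0.0, 1.0, training=False)
```

So if the learning phase is set to 0 before the model is built and trained, the moving statistics would be expected to stay at their initial values, which could explain a layer that appears to do nothing.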
Update 3:
Final remarks:
- Putting K.set_learning_phase(0) before creating the model will let it save, however the Batch Normalisation layer doesn’t seem to do anything (updating turned off?), so it is not a solution.
- Changing graph_util_impl.py to its 1.13.1 version will let it save without error, however loading the frozen graph from the protobuf then fails with: ValueError: Input 0 of node batch_normalization/cond/ReadVariableOp/Switch was passed float from batch_normalization/gamma:0 incompatible with expected resource.
- The workaround is to save the model weights, clear the session (so that tensor names are not different because of having two graphs), set learning phase to 0, recreate the model, load the weights, and then freeze (example code in my comment below).
About this issue
- State: closed
- Created 5 years ago
- Reactions: 1
- Comments: 17 (2 by maintainers)
@wenshuangwang I finally found a workaround: save the weights with model.save_weights("weights.h5"), clear the session, set the learning phase to 0 (K.set_learning_phase(0)), recreate the model, load the weights, and then freeze.
I have confirmed this works with no errors for creating the frozen model protobuf. There are also no errors when loading the model (using tf.import_graph_def), and it is working in my production code (using TensorFlow Java 1.14).
Here is a full working example:
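The example itself is not preserved in this copy of the thread. What follows is a minimal sketch rebuilt from the steps described above, not the original code. It assumes TensorFlow 1.14 and reuses the architecture from the reproduction script; the function name `freeze_with_workaround` and its arguments are illustrative.

```python
import os


def freeze_with_workaround(save_dir, weights_path="weights.h5"):
    """Save weights, rebuild the model in inference mode, then freeze it.

    Sketch for TensorFlow 1.14. `save_dir` must not exist yet, because
    tf.saved_model.simple_save refuses to overwrite a directory.
    """
    # Imported here so that defining the function does not itself
    # require TensorFlow 1.x to be installed.
    import tensorflow as tf
    import tensorflow.keras.backend as K
    from tensorflow.keras.layers import (Activation, BatchNormalization,
                                         Conv2D, Dense, Flatten, Input)
    from tensorflow.keras.models import Model
    from tensorflow.python.tools import freeze_graph

    def build_model():
        # Same architecture as in the reproduction code.
        inputs = Input(shape=(128, 128, 1))
        x = Conv2D(4, (3, 3))(inputs)
        x = BatchNormalization()(x)
        x = Activation('relu')(x)
        x = Flatten()(x)
        x = Dense(5, activation='softmax')(x)
        return Model(inputs, x, name='test')

    # 1. Build the model with the default learning phase.
    model = build_model()
    model.compile(optimizer='adam', loss='categorical_crossentropy',
                  metrics=['accuracy'])
    # model.fit(...) would go here in real use.

    # 2. Save only the weights.
    model.save_weights(weights_path)

    # 3. Clear the session so the rebuilt layers get the same tensor names,
    #    then switch to inference mode before recreating the model.
    K.clear_session()
    K.set_learning_phase(0)

    # 4. Recreate the model and restore the weights.
    model = build_model()
    model.load_weights(weights_path)

    # 5. Export and freeze exactly as in the original report.
    tf.saved_model.simple_save(K.get_session(),
                               save_dir,
                               inputs={"input": model.inputs[0]},
                               outputs={"output": model.outputs[0]})
    frozen_path = os.path.join(save_dir, "frozen_model.pb")
    freeze_graph.freeze_graph(None, None, None, None,
                              model.outputs[0].op.name,
                              None, None,
                              frozen_path,
                              False, "",
                              input_saved_model_dir=save_dir)
    return frozen_path
```

On TF 1.14, calling `freeze_with_workaround("./tmp_export")` should write `frozen_model.pb` into the export directory. Behaviour on other TensorFlow versions is not covered by this sketch.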