tensorflow: Exception when saving custom RNN model with a constant in the call function when using SavedModel format
System information
- Have I written custom code (as opposed to using a stock example script provided in TensorFlow): Yes
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Linux Redhat 7.8.2
- Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if the issue happens on mobile device:
- TensorFlow installed from (source or binary): Binary
- TensorFlow version (use command below): 2.4.1 and 2.3.1 (tested)
- Python version: 3.6.8
- Bazel version (if compiling from source): No
- GCC/Compiler version (if compiling from source): No
- CUDA/cuDNN version: No
- GPU model and memory: No
Current behaviour
When saving an RNN model in the SavedModel format with a constant in the call function (as shown below), we get an exception.
Desired behaviour
We should be able to save the model defined below in the SavedModel format, just as we can save it in the h5 format.
Code to reproduce the issue:
import tensorflow as tf
import tensorflow.keras as tfk
import tensorflow.keras.backend as K
import numpy as np


class MinimalRNNCell(tfk.layers.Layer):

    def __init__(self, units, **kwargs):
        self.units = units
        self.state_size = units
        super(MinimalRNNCell, self).__init__(**kwargs)

    def build(self, input_shape):
        self.kernel = self.add_weight(shape=(input_shape[-1], self.units),
                                      initializer='uniform',
                                      name='kernel')
        self.recurrent_kernel = self.add_weight(
            shape=(self.units, self.units),
            initializer='uniform',
            name='recurrent_kernel')
        self.built = True

    def call(self, inputs, states=None, constants=None, training=False, *args, **kwargs):
        prev_output = states[0]
        print("constants: ", constants)
        h = K.dot(inputs, self.kernel) + constants[0]
        output = h + K.dot(prev_output, self.recurrent_kernel)
        return output, [output]

    def get_config(self):
        return dict(super().get_config(), **{'units': self.units})


cell = MinimalRNNCell(32)
x = tfk.Input((None, 5), name='x')
z = tfk.Input((1), name='z')
layer = tfk.layers.RNN(cell, name='rnn')
y = layer(x, constants=[z])

model = tfk.Model(inputs=[x, z], outputs=[y])
model.compile(optimizer='adam', loss='mse')
model.save('tmp.h5')  # This works ok
model_loaded = tfk.models.load_model('tmp.h5', custom_objects={'MinimalRNNCell': MinimalRNNCell})
print(model_loaded.predict([np.array([[[0,0,0,0,0]]]), np.array([[0]])]))  # This works ok
model.save('tmp2')  # This throws an exception
Other info
The stdout from the above is:
constants: (<tf.Tensor 'Placeholder_1:0' shape=(None, 1) dtype=float32>,)
constants: (<tf.Tensor 'Placeholder_1:0' shape=(None, 1) dtype=float32>,)
constants: (<tf.Tensor 'Placeholder_1:0' shape=(None, 1) dtype=float32>,)
constants: (<tf.Tensor 'Placeholder_1:0' shape=(None, 1) dtype=float32>,)
constants: (<tf.Tensor 'model_1/Cast_1:0' shape=(None, 1) dtype=float32>,)
constants: (<tf.Tensor 'model_1/Cast_1:0' shape=(None, 1) dtype=float32>,)
[[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0. 0. 0.]]
constants: (<tf.Tensor 'z:0' shape=(None, 1) dtype=float32>,)
constants: (<tf.Tensor 'z:0' shape=(None, 1) dtype=float32>,)
constants: (<tf.Tensor 'constants:0' shape=(None, None, 5) dtype=float32>,)
I’ll attach the full exception here. Anyone should be able to reproduce it. The main error is:
ValueError: Dimensions must be equal, but are 32 and 5 for '{{node add}} = AddV2[T=DT_FLOAT](MatMul, constants)' with input shapes: [?,32], [?,?,5].
For some reason, the name of z changes to constants and the shape changes from the expected (None, 1) to (None, None, 5).
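To make the shape mismatch concrete, here is a minimal pure-Python sketch of the usual trailing-dimension broadcasting rule. The helper `broadcast_shape` is hypothetical and written only for illustration; it is not a TensorFlow API, and it assumes the add op follows NumPy-style broadcasting, with `None` standing for an unknown dimension (shown as `?` in the error above).

```python
from itertools import zip_longest


def broadcast_shape(a, b):
    """Return the broadcast shape of two shapes, or raise ValueError.

    Hypothetical helper for illustration only. Shapes are tuples of ints;
    None marks an unknown dimension.
    """
    result = []
    # Compare dimensions from the last axis backwards, padding the shorter shape with 1s.
    for x, y in zip_longest(reversed(a), reversed(b), fillvalue=1):
        if x == 1:
            result.append(y)
        elif y == 1:
            result.append(x)
        elif x is None or y is None:
            # An unknown dimension is assumed compatible; keep the known size if there is one.
            result.append(y if x is None else x)
        elif x == y:
            result.append(x)
        else:
            raise ValueError(f"Dimensions must be equal, but are {x} and {y}")
    return tuple(reversed(result))


# Expected case: MatMul output (None, 32) plus the constant z with shape (None, 1).
print(broadcast_shape((None, 32), (None, 1)))  # (None, 32)

# Case seen while saving: the "constant" arrives with shape (None, None, 5).
try:
    broadcast_shape((None, 32), (None, None, 5))
except ValueError as err:
    print(err)  # Dimensions must be equal, but are 32 and 5
```

With the expected `(None, 1)` constant the addition broadcasts cleanly, while a tensor shaped like the input `x` cannot be added to the `(None, 32)` product. That matches the reported error and suggests the cell is being traced with `x` in the `constants` slot.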
Any ideas appreciated.
Thanks in advance.
About this issue
- Original URL
- State: open
- Created 3 years ago
- Comments: 26 (14 by maintainers)
@mohantym you will notice that the original filed issue (at the top) was actually saving as .h5. So changing the extension to .h5 is not sufficient.
Specifically, with TensorFlow 2.6:
So first, that doesn’t always fix the bug.
Second, even if it is “resolved in some cases”, that does not mean it is resolved in all cases, so this bug should stay open. Otherwise, other people will hit random breakage and have to go through their own whole cycle of debugging, ticket filing, etc.