keras: Error when Saving model with data augmentation layer on Tensorflow 2.7
I am getting an error when trying to save a model with data augmentation layers in the latest TensorFlow version (2.7.0).
Here is the code of data augmentation:
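# Assumed imports (not shown in the original post)
import tensorflow as tf
from tensorflow.keras import Sequential, layers

input_shape_rgb = (img_height, img_width, 3)
data_augmentation_rgb = tf.keras.Sequential(
    [
        layers.RandomFlip("horizontal"),
        layers.RandomFlip("vertical"),
        layers.RandomRotation(0.5),
        layers.RandomZoom(0.5),
        layers.RandomContrast(0.5),
        # Custom layer defined elsewhere (not part of tf.keras.layers)
        RandomColorDistortion(name='random_contrast_brightness/none'),
    ]
)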
Now I build my model like this:
input_shape = (img_height, img_width, 3)
model = Sequential([
    layers.Input(input_shape),
    data_augmentation_rgb,
    layers.Rescaling((1./255)),
    layers.Conv2D(16, kernel_size, padding=padding, activation='relu', strides=1,
                  data_format='channels_last'),
    layers.MaxPooling2D(),
    layers.BatchNormalization(),
    layers.Conv2D(32, kernel_size, padding=padding, activation='relu'),  # best 4
    layers.MaxPooling2D(),
    layers.BatchNormalization(),
    layers.Conv2D(64, kernel_size, padding=padding, activation='relu'),  # best 3
    layers.MaxPooling2D(),
    layers.BatchNormalization(),
    layers.Conv2D(128, kernel_size, padding=padding, activation='relu'),  # best 3
    layers.MaxPooling2D(),
    layers.BatchNormalization(),
    layers.Flatten(),
    layers.Dense(128, activation='relu'),  # best 1
    layers.Dropout(0.1),
    layers.Dense(128, activation='relu'),  # best 1
    layers.Dropout(0.1),
    layers.Dense(64, activation='relu'),  # best 1
    layers.Dropout(0.1),
    layers.Dense(num_classes, activation='softmax')
])
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=metrics)
model.summary()
Then, after training is done, I just call:
model.save("./")
And I’m getting this error:
---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
<ipython-input-84-87d3f09f8bee> in <module>()
----> 1 model.save("./")
/usr/local/lib/python3.7/dist-packages/keras/utils/traceback_utils.py in error_handler(*args, **kwargs)
     65     except Exception as e:  # pylint: disable=broad-except
     66       filtered_tb = _process_traceback_frames(e.__traceback__)
---> 67       raise e.with_traceback(filtered_tb) from None
     68     finally:
     69       del filtered_tb
/usr/local/lib/python3.7/dist-packages/tensorflow/python/saved_model/function_serialization.py in serialize_concrete_function(concrete_function, node_ids, coder)
     66   except KeyError:
     67     raise KeyError(
---> 68         f"Failed to add concrete function '{concrete_function.name}' to object-"
     69         f"based SavedModel as it captures tensor {capture!r} which is unsupported"
     70         " or not reachable from root. "
KeyError: "Failed to add concrete function 'b'__inference_sequential_46_layer_call_fn_662953'' to object-based SavedModel as it captures tensor <tf.Tensor: shape=(), dtype=resource, value=<Resource Tensor>> which is unsupported or not reachable from root. One reason could be that a stateful object or a variable that the function depends on is not assigned to an attribute of the serialized trackable object (see SaveTest.test_captures_unreachable_variable)."
I investigated the cause of this error by changing the architecture of my model, and found that it comes from the data_augmentation layer. RandomFlip, RandomRotation and the other augmentation layers were moved from layers.experimental.preprocessing.RandomFlip to layers.RandomFlip, but the error still appears with the new names.
Here is the main issue I posted on the TensorFlow repo: https://github.com/tensorflow/tensorflow/issues/53053#
About this issue
- State: closed
- Created 3 years ago
- Comments: 15 (5 by maintainers)
Hi @moumed, this error means you are using objects that aren't serializable. As a matter of fact, you have passed arguments to the layers, and those cannot be saved. The proper way to do this is to create a class for the model that inherits from keras.Model, or a class for the whole thing as a layer that inherits from keras.layers.Layer, and then define the get_config and from_config methods. The documentation can be found here. A workaround is saving the model as a .h5 model and then loading only the weights from it. Hope this helps.
I have the same problem. In fact, with TF 2.7, saving and loading a model seems to have some issues.
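To make the get_config / from_config suggestion above concrete, here is a minimal sketch of a custom layer whose constructor argument survives a config round trip. The layer name RandomBrightnessShift and its max_delta parameter are hypothetical and only stand in for a custom layer such as RandomColorDistortion; this is not the code from the issue, and it is not confirmed to resolve the SavedModel error reported above.

```python
import tensorflow as tf


class RandomBrightnessShift(tf.keras.layers.Layer):
    """Hypothetical augmentation layer used only to illustrate get_config."""

    def __init__(self, max_delta=0.1, **kwargs):
        super().__init__(**kwargs)
        # Plain Python value: easy to store in the layer config.
        self.max_delta = max_delta

    def call(self, inputs, training=None):
        # Assumes float inputs; augmentation is applied only in training mode.
        if not training:
            return inputs
        delta = tf.random.uniform([], -self.max_delta, self.max_delta)
        return inputs + delta

    def get_config(self):
        # Start from the base config (name, dtype, ...) and add our argument.
        config = super().get_config()
        config.update({"max_delta": self.max_delta})
        return config


layer = RandomBrightnessShift(max_delta=0.2, name="random_brightness_shift")
config = layer.get_config()

# The default Layer.from_config calls cls(**config), which is sufficient here
# because every constructor argument is present in the config.
restored = RandomBrightnessShift.from_config(config)
```

When a saved model containing such a layer is loaded, the class is usually passed in through the custom_objects argument, for example tf.keras.models.load_model(path, custom_objects={"RandomBrightnessShift": RandomBrightnessShift}).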
Hi @ashwinshetgaonkar, the error arises because you haven't loaded the custom objects, like the factors for RandomRotation or RandomContrast. They cannot be saved without a config, since these values are not serialized. Instead of loading the model, you can try load_weights if you don't want to declare methods for the config. This is not a workaround, as all the objects you want to save have to be serialized. @moumed, you can refer to the issue you raised on Stack Overflow, which says the same, i.e., you have to use the get_config and from_config methods.
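The weights-only route mentioned in the comments can be sketched as follows. This is a small illustration under stated assumptions, not the model from the issue: build_model is a hypothetical helper, the architecture is reduced to a few layers, and the file name model.weights.h5 is arbitrary. The idea is that the architecture is rebuilt from code, so only the weights need to be written to disk.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


def build_model():
    # Hypothetical helper: rebuilds the same architecture from code each time.
    return tf.keras.Sequential([
        tf.keras.layers.Input(shape=(8, 8, 3)),
        tf.keras.layers.RandomFlip("horizontal"),
        tf.keras.layers.Conv2D(4, 3, activation="relu"),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(2, activation="softmax"),
    ])


model = build_model()

# Save only the weights (HDF5 format), not the full SavedModel.
weights_path = os.path.join(tempfile.mkdtemp(), "model.weights.h5")
model.save_weights(weights_path)

# Later: rebuild the architecture from code and load the weights into it.
restored = build_model()
restored.load_weights(weights_path)

# At inference time the augmentation layer is inactive, so both models
# should produce the same predictions for the same input.
x = np.random.rand(1, 8, 8, 3).astype("float32")
original_pred = model.predict(x, verbose=0)
restored_pred = restored.predict(x, verbose=0)
```

The trade-off is that the code defining the architecture has to be available wherever the weights are loaded, since nothing about the layers themselves is stored in the weights file.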