tensorflow: TF 2.3 keras.models.load_model with compile=False fails to load a SavedModel, but TF 2.0 works
System information
- Have I written custom code (as opposed to using a stock example script provided in TensorFlow): Uses TensorFlow Addons
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04): macOS (also reproducible on Colab)
- TensorFlow installed from (source or binary): pip
- TensorFlow version (use command below): 2.3 (fails); 2.0 (works)
- Python version: python3
Describe the current behavior
I use F1Score from TensorFlow Addons as a metric. After training, I load the SavedModel with keras.models.load_model and set compile=False, but I get this error:
ValueError: Unable to restore custom object of type _tf_keras_metric currently. Please make sure that the layer implements `get_config`and `from_config` when saving. In addition, please use the `custom_objects` arg when calling `load_model()`.
This happens with tf2.3, but works with tf2.0.
Describe the expected behavior
If compile=False is set, loading shouldn't check the metrics or losses.
Standalone code to reproduce the issue
- CODE
import numpy as np
import tensorflow as tf
import tensorflow_addons as tfa
print(tf.__version__)
_input = tf.keras.layers.Input(shape=(500,), name="fbank")  # (batch, 500)
out = tf.keras.layers.Dense(50, activation="tanh")(_input)
probabilities = tf.keras.layers.Dense(2, activation="softmax")(out)
model = tf.keras.Model(inputs=_input, outputs=probabilities)
model.compile(optimizer="sgd", loss=tf.keras.losses.CategoricalCrossentropy(),
              metrics=["accuracy", tfa.metrics.F1Score(num_classes=2, average="micro")])
model.summary()
x = np.random.rand(300, 500)
y = np.random.rand(300, 2)
model.fit(x, y, batch_size=100, epochs=2)
path = 'saved_model/'
model.save(path, save_format='tf')
del model
model = tf.keras.models.load_model('saved_model', compile=False)
- OUTPUT
2.3.0
Model: "functional_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
fbank (InputLayer)           [(None, 500)]             0
_________________________________________________________________
dense (Dense)                (None, 50)                25050
_________________________________________________________________
dense_1 (Dense)              (None, 2)                 102
=================================================================
Total params: 25,152
Trainable params: 25,152
Non-trainable params: 0
_________________________________________________________________
Epoch 1/2
3/3 [==============================] - 0s 4ms/step - loss: 0.7292 - accuracy: 0.5033 - f1_score: 0.0000e+00
Epoch 2/2
3/3 [==============================] - 0s 4ms/step - loss: 0.7192 - accuracy: 0.5200 - f1_score: 0.0000e+00
WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/training/tracking/tracking.py:111: Model.state_updates (from tensorflow.python.keras.engine.training) is deprecated and will be removed in a future version.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.
WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/training/tracking/tracking.py:111: Layer.updates (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.
INFO:tensorflow:Assets written to: saved_model/assets
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-1-9ff0edc2f186> in <module>()
23
24 del model
---> 25 model = tf.keras.models.load_model('saved_model', compile=False)
8 frames
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/saving/saved_model/load.py in revive_custom_object(identifier, metadata)
844 'and `from_config` when saving. In addition, please use '
845 'the `custom_objects` arg when calling `load_model()`.'
--> 846 .format(identifier))
847
848
ValueError: Unable to restore custom object of type _tf_keras_metric currently. Please make sure that the layer implements `get_config`and `from_config` when saving. In addition, please use the `custom_objects` arg when calling `load_model()`.
Colab notebook:
https://colab.research.google.com/drive/17DI2N1L9EKSJ8-Ua88mcSnkmRT5adna3?usp=sharing
About this issue
- State: closed
- Created 4 years ago
- Comments: 23 (10 by maintainers)
Can you try loading with
custom_objects={"F1Score": tfa.metrics.F1Score}
? I've used !pip install --upgrade tensorflow tensorflow_addons in your Colab.
/cc @marload
The official documentation says that with the SavedModel format we don't need to worry about custom objects.
The model should load successfully with
model = tf.keras.models.load_model('saved_model')
Even if there is a custom object, it should not prevent the model from loading.
When we compile and train a model and then distribute it to others, we expect it to load and run inference correctly without having to share the source code.
If someone only wants to load the model for prediction and inference, without retraining, I found a workaround. Use this code before model.save
Thanks @bhack and @amahendrakar. This does make the model load successfully.
But the key question is why the same code behaves differently in TF 2.0 and TF 2.3.
In addition, if I set compile=False, why should the custom metric still matter?
@Liu-Da I'm experiencing this problem as well. I'm using
from tensorflow.keras.callbacks import ModelCheckpoint
to save the model. Do you see a way to adapt your workaround when using ModelCheckpoint?
Was able to reproduce the issue.
The code works with TF v2.0 but throws an error with TF v2.3:
ValueError: Unable to restore custom object of type _tf_keras_metric currently. Please make sure that the layer implements get_config and from_config when saving. In addition, please use the custom_objects arg when calling load_model().
Please find the gist of it here. Thanks!
It is referencing an internal ticket. I've already mentioned @k-w-w.