tensorflow: TF2 keras.models.load_model fails with custom metrics (both h5 and tf format)
- Have I written custom code (as opposed to using a stock example script provided in TensorFlow): yes
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04): macOS 10.13.6
- TensorFlow installed from (source or binary): pip install tensorflow==2.0.0-beta1
- TensorFlow version (use command below): v2.0.0-rc2-26-g64c3d382ca 2.0.0
- Python version: v3.6.7:6ec5cf24b7, Oct 20 2018, 03:02:14
Describe the current behavior
I have a custom metric in my model, and using tf.keras.models.load_model with compile=True after saving the model results in an error in almost all cases, even though I use the custom_objects argument according to the documentation.
I tried to pass my custom metric with two strategies: by passing a custom function, custom_accuracy, to the tf.keras.Model.compile method, or by subclassing the MeanMetricWrapper class and giving an instance of my subclass, named CustomAccuracy, to tf.keras.Model.compile.
I also tried the two different saving formats available: h5 and tf. Here are my results:
- with tf format:
  - with custom function: fails with ValueError, message "Unknown metric function:custom_accuracy"
  - with subclassed metric: fails with ValueError, message "Unknown metric function: CustomAccuracy"
- with h5 format:
  - with custom function: success
  - with subclassed metric: fails with TypeError, message "must be str, not ABCMeta"
Note that, judging from the complete error logs (see below), the error with the h5 format and the subclassed metric is in fact the same as the error with the tf format: the TypeError occurs when the code tries to raise the ValueError.
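The TypeError-inside-ValueError chain can be reproduced in plain Python. The last frame of the traceback builds the message with 'Unknown ' + printable_module_name + ': ' + class_name; if the value in class_name is presumably a class object rather than its name, the concatenation fails before the ValueError exists. In this minimal sketch, FakeMetric and build_message are hypothetical stand-ins, not TensorFlow code. The exact TypeError wording depends on the Python version (3.6 prints "must be str, not ABCMeta").

```python
import abc


class FakeMetric(abc.ABC):
    """Hypothetical stand-in for a metric class; its metaclass is ABCMeta."""


def build_message(printable_module_name, class_name):
    # Mirrors the string concatenation shown in the last traceback frame.
    return 'Unknown ' + printable_module_name + ': ' + class_name


# With a string class name, the intended ValueError message is built.
print(build_message('metric function', 'CustomAccuracy'))

# With the class object itself, building the message raises TypeError.
try:
    build_message('metric function', FakeMetric)
except TypeError as e:
    print(type(e).__name__, e)
```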
Describe the expected behavior
This should not fail in any case, unless I am using the custom_objects argument incorrectly. The documentation could be expanded a little on that matter, by the way.
Code to reproduce the issue
import tensorflow as tf
print('Using Tensorflow version {} (git version {})'.format(tf.version.VERSION, tf.version.GIT_VERSION))

from tensorflow.python.keras.metrics import MeanMetricWrapper
from tensorflow.python.keras.metrics import accuracy


# Strategy 1: a plain function used as a metric.
def custom_accuracy(y_true, y_pred):
    return accuracy(y_true, y_pred)


# Strategy 2: a metric class wrapping the same function.
class CustomAccuracy(MeanMetricWrapper):
    def __init__(self, **kwargs):
        super(CustomAccuracy, self).__init__(custom_accuracy, **kwargs)


def make_model():
    inp = tf.keras.Input(shape=(2,))
    x = tf.keras.layers.Dense(4)(inp)
    return tf.keras.Model(inp, x)


for save_format in ['tf', 'h5']:
    print("\nTrying with save_format='{}':\n".format(save_format))

    # Save, then reload with compile=True, a model compiled with the custom function.
    model_with_function = make_model()
    model_with_function.compile(loss='mse', metrics=[custom_accuracy])
    model_with_function.save('/tmp/model_with_function' + '.' + save_format,
                             save_format=save_format)
    try:
        new_model = tf.keras.models.load_model('/tmp/model_with_function' + '.' + save_format,
                                               custom_objects={'custom_accuracy': custom_accuracy},
                                               compile=True)
        print("model_with_function loaded with the following metrics:")
        print(new_model.metrics)
    except Exception as e:
        print("model_with_function not loaded with the following error:")
        print(type(e))
        print(e)

    # Same round trip with an instance of the subclassed metric.
    model_with_subclass = make_model()
    model_with_subclass.compile(loss='mse', metrics=[CustomAccuracy()])
    model_with_subclass.save('/tmp/model_with_subclass' + '.' + save_format,
                             save_format=save_format)
    try:
        new_model = tf.keras.models.load_model('/tmp/model_with_subclass' + '.' + save_format,
                                               custom_objects={'CustomAccuracy': CustomAccuracy},
                                               compile=True)
        print("model_with_subclass loaded with the following metrics:")
        print(new_model.metrics)
    except Exception as e:
        print("model_with_subclass not loaded with the following error:")
        print(type(e))
        print(e)
Other info / logs
The logs are the same in the three error cases (to get them with the code above, just add raise at the end of the except blocks):
<ipython-input-14-ac0a72b492dc> in <module>
48 new_model = tf.keras.models.load_model('/tmp/model_with_subclass' + '.' + save_format,
49 custom_objects={'CustomAccuracy': CustomAccuracy},
---> 50 compile=True)
51 print("model_with_function loaded with the following metrics:")
52 print(new_model.metrics)
/path/to/tensorflow_core/python/keras/saving/save.py in load_model(filepath, custom_objects, compile)
148 if isinstance(filepath, six.string_types):
149 loader_impl.parse_saved_model(filepath)
--> 150 return saved_model_load.load(filepath, compile)
151
152 raise IOError(
/path/to/tensorflow_core/python/keras/saving/saved_model/load.py in load(path, compile)
91 if model._training_config is not None: # pylint: disable=protected-access
92 model.compile(**saving_utils.compile_args_from_training_config(
---> 93 model._training_config)) # pylint: disable=protected-access
94
95 return model
/path/to/tensorflow_core/python/training/tracking/base.py in _method_wrapper(self, *args, **kwargs)
455 self._self_setattr_tracking = False # pylint: disable=protected-access
456 try:
--> 457 result = method(self, *args, **kwargs)
458 finally:
459 self._self_setattr_tracking = previous_value # pylint: disable=protected-access
/path/to/tensorflow_core/python/keras/engine/training.py in compile(self, optimizer, loss, metrics, loss_weights, sample_weight_mode, weighted_metrics, target_tensors, distribute, **kwargs)
354 with K.get_graph().as_default():
355 # Save all metric attributes per output of the model.
--> 356 self._cache_output_metric_attributes(metrics, weighted_metrics)
357
358 # Set metric attributes on model.
/path/to/tensorflow_core/python/keras/engine/training.py in _cache_output_metric_attributes(self, metrics, weighted_metrics)
1899 output_shapes.append(output.shape.as_list())
1900 self._per_output_metrics = training_utils.collect_per_output_metric_info(
-> 1901 metrics, self.output_names, output_shapes, self.loss_functions)
1902 self._per_output_weighted_metrics = (
1903 training_utils.collect_per_output_metric_info(
/path/to/tensorflow_core/python/keras/engine/training_utils.py in collect_per_output_metric_info(metrics, output_names, output_shapes, loss_fns, is_weighted)
811 metrics_dict = OrderedDict()
812 for metric in metrics:
--> 813 metric_name = get_metric_name(metric, is_weighted)
814 metric_fn = get_metric_function(
815 metric, output_shape=output_shapes[i], loss_fn=loss_fns[i])
/path/to/tensorflow_core/python/keras/engine/training_utils.py in get_metric_name(metric, weighted)
985 return metric
986
--> 987 metric = metrics_module.get(metric)
988 return metric.name if hasattr(metric, 'name') else metric.__name__
989 else:
/path/to/tensorflow_core/python/keras/metrics.py in get(identifier)
2855 def get(identifier):
2856 if isinstance(identifier, dict):
-> 2857 return deserialize(identifier)
2858 elif isinstance(identifier, six.string_types):
2859 return deserialize(str(identifier))
/path/to/tensorflow_core/python/keras/metrics.py in deserialize(config, custom_objects)
2849 module_objects=globals(),
2850 custom_objects=custom_objects,
-> 2851 printable_module_name='metric function')
2852
2853
/path/to/tensorflow_core/python/keras/utils/generic_utils.py in deserialize_keras_object(identifier, module_objects, custom_objects, printable_module_name)
178 config = identifier
179 (cls, cls_config) = class_and_config_for_serialized_keras_object(
--> 180 config, module_objects, custom_objects, printable_module_name)
181
182 if hasattr(cls, 'from_config'):
/path/to/tensorflow_core/python/keras/utils/generic_utils.py in class_and_config_for_serialized_keras_object(config, module_objects, custom_objects, printable_module_name)
163 cls = module_objects.get(class_name)
164 if cls is None:
--> 165 raise ValueError('Unknown ' + printable_module_name + ': ' + class_name)
166 return (cls, config['config'])
167
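The last frame suggests that the class name stored in the saved config is used as a dictionary key, so an object is only found if some mapping contains exactly that name. Note also that in this call path metrics.get calls deserialize(identifier) without passing custom_objects explicitly; whether the dictionary given to load_model is consulted depends on how load_model forwards it, which the log does not show. The sketch below is a simplified, hypothetical model of the lookup (resolve_class and the checking order are assumptions, not the TensorFlow source). It illustrates why keying custom_objects by the class's __name__ is the safe choice.

```python
class CustomAccuracy:
    """Hypothetical stand-in for the user's metric class."""


def resolve_class(config, module_objects, custom_objects=None):
    # Simplified model (assumption): look the stored class name up in
    # custom_objects first, then in the module's own namespace.
    class_name = config['class_name']
    custom_objects = custom_objects or {}
    if class_name in custom_objects:
        cls = custom_objects[class_name]
    else:
        cls = module_objects.get(class_name)
    if cls is None:
        raise ValueError('Unknown metric function: ' + class_name)
    return cls, config['config']


# Hypothetical serialized entry for the subclassed metric.
saved = {'class_name': 'CustomAccuracy', 'config': {}}

# Keyed by the class name: the class is found.
print(resolve_class(saved, {}, {'CustomAccuracy': CustomAccuracy}))

# Keyed by any other name: the value is never looked at.
try:
    resolve_class(saved, {}, {'my_metric': CustomAccuracy})
except ValueError as e:
    print(e)

# Building the dict from __name__ keeps keys and class names in sync.
custom_objects = {cls.__name__: cls for cls in [CustomAccuracy]}
```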
About this issue
- State: closed
- Created 5 years ago
- Reactions: 2
- Comments: 17 (9 by maintainers)
Hello! Was this ever solved for saving/loading custom metrics in the SavedModel format, as opposed to .h5?
I had the same issue, only my error was:

I had subclassed MeanMetricWrapper, so it couldn’t possibly have been a lack of implementing get_config and from_config, and I had already set up the custom_objects dict, which had:

Everything was referenced correctly in the main script (the model would run both manually and through hyperparameter searches), but I kept getting this error whenever I tried loading the saved TF model. I then switched to saving/loading an H5 model instead, and got an error stating that MeanAbsoluteScaledErrorMetric wasn’t included in custom_objects. So I updated it to:

…and then it worked. I then switched back to the TF model and it kept working. So in the end, I suppose that somewhere the loader is not respecting the key/value relationship in custom_objects and is only looking for the class name in the keys.

I have seen your gist, and after installing tf-nightly I have been able to replicate it on my laptop, thank you.
The only small difference I see is that locally I have an additional warning:
WARNING: Logging before flag parsing goes to stderr.