tensorflow: TF2 keras.models.load_model fails with custom metrics (both h5 and tf format)

  • Have I written custom code (as opposed to using a stock example script provided in TensorFlow): yes
  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): macOS 10.13.6
  • TensorFlow installed from (source or binary): pip install tensorflow==2.0.0-beta1
  • TensorFlow version (use command below): v2.0.0-rc2-26-g64c3d382ca 2.0.0
  • Python version: v3.6.7:6ec5cf24b7, Oct 20 2018, 03:02:14

Describe the current behavior

I have a custom metric in my model, and calling tf.keras.models.load_model with compile=True after saving results in an error in almost all cases, even though I use the custom_objects argument as described in the documentation.

I tried to pass my custom metric in two ways: by passing a custom function custom_accuracy to the tf.keras.Model.compile method, and by subclassing the MeanMetricWrapper class and passing an instance of my subclass, named CustomAccuracy, to tf.keras.Model.compile.

I also tried the two different save formats available, h5 and tf. Here are my results:

  • with tf format:
    • with custom function: fails with ValueError message Unknown metric function: custom_accuracy
    • with subclassed metric: fails with ValueError message Unknown metric function: CustomAccuracy
  • with h5 format:
    • with custom function: succeeds
    • with subclassed metric: fails with TypeError message must be str, not ABCMeta

Note that, given the complete error logs (see below), the error with the h5 format and the subclassed metric is in fact the same as the error with the tf format: the TypeError occurs when the code tries to raise the ValueError.
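
To make that concrete, here is a minimal standalone illustration of why the intended ValueError surfaces as a TypeError when the deserializer ends up holding a class object instead of a string (the raise line mirrors the one in generic_utils.py shown in the traceback below):

# Minimal illustration: concatenating a str with a class object fails before
# the ValueError can even be constructed. Keras metric classes use the ABCMeta
# metaclass, hence the "must be str, not ABCMeta" message in my case.
class SomeMetric:  # stands in for the deserialized metric class
    pass

printable_module_name = 'metric function'
class_name = SomeMetric  # a class, not a string

try:
    raise ValueError('Unknown ' + printable_module_name + ': ' + class_name)
except TypeError as e:
    print(e)  # on Python 3.6 this prints: must be str, not type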

Describe the expected behavior

This should not fail in any case, unless I am using the custom_objects argument incorrectly. The documentation could be expanded a little on that point, by the way.
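
For reference, the pattern I am following is the documented one: map the name under which the custom object was serialized to the Python object in custom_objects. A minimal sketch with a custom loss (the model, path and loss name are illustrative only):

# Minimal sketch of the documented custom_objects usage; names and paths
# below are illustrative, not taken from the failing script.
import tensorflow as tf

def custom_mse(y_true, y_pred):
    return tf.reduce_mean(tf.square(y_true - y_pred))

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(2,))])
model.compile(optimizer='sgd', loss=custom_mse)
model.save('/tmp/model_with_custom_loss.h5')

# The key is the name under which the function was serialized.
restored = tf.keras.models.load_model(
    '/tmp/model_with_custom_loss.h5',
    custom_objects={'custom_mse': custom_mse})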

Code to reproduce the issue

import tensorflow as tf

print('Using Tensorflow version {} (git version {})'.format(tf.version.VERSION, tf.version.GIT_VERSION))

from tensorflow.python.keras.metrics import MeanMetricWrapper
from tensorflow.python.keras.metrics import accuracy

# Strategy 1: a plain custom function wrapping the built-in accuracy.
def custom_accuracy(y_true, y_pred):
    return accuracy(y_true, y_pred)

# Strategy 2: a Metric subclass built on MeanMetricWrapper.
class CustomAccuracy(MeanMetricWrapper):

    def __init__(self, **kwargs):
        super(CustomAccuracy, self).__init__(custom_accuracy, **kwargs)
        
def make_model():
    inp = tf.keras.Input(shape=(2,))
    x = tf.keras.layers.Dense(4)(inp)
    return tf.keras.Model(inp, x)


# Exercise both strategies with both save formats.
for save_format in ['tf', 'h5']:
    
    print("\nTrying with save_format='{}':\n".format(save_format))
    
    model_with_function = make_model()
    model_with_function.compile(loss='mse', metrics=[custom_accuracy])
    model_with_function.save('/tmp/model_with_function' + '.' + save_format, 
                             save_format=save_format)
    
    try:
        new_model = tf.keras.models.load_model('/tmp/model_with_function' + '.' + save_format, 
                                               custom_objects={'custom_accuracy': custom_accuracy}, 
                                               compile=True)
        print("model_with_function loaded with the following metrics:")
        print(new_model.metrics)
    except Exception as e:
        print("model_with_function not loaded with the following error:")
        print(type(e))
        print(e)
        
    model_with_subclass = make_model()
    model_with_subclass.compile(loss='mse', metrics=[CustomAccuracy()])
    model_with_subclass.save('/tmp/model_with_subclass' + '.' + save_format, 
                             save_format=save_format)
    
    try:
        new_model = tf.keras.models.load_model('/tmp/model_with_subclass' + '.' + save_format, 
                                               custom_objects={'CustomAccuracy': CustomAccuracy}, 
                                               compile=True)
        print("model_with_subclass loaded with the following metrics:")
        print(new_model.metrics)
    except Exception as e:
        print("model_with_subclass not loaded with the following error:")
        print(type(e))
        print(e)
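
A possible workaround (untested beyond this script) is to load with compile=False and recompile manually, which skips the metric deserialization entirely; a sketch using the same objects as above:

# Possible workaround sketch: skip the saved training config and recompile
# with the original Python objects instead of deserializing them.
new_model = tf.keras.models.load_model('/tmp/model_with_subclass.tf', compile=False)
new_model.compile(loss='mse', metrics=[CustomAccuracy()])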

Other info / logs

The logs are the same in the three error cases (to get them with the code above, just add raise at the end of the except blocks):

<ipython-input-14-ac0a72b492dc> in <module>
     48         new_model = tf.keras.models.load_model('/tmp/model_with_subclass' + '.' + save_format, 
     49                                                custom_objects={'CustomAccuracy': CustomAccuracy},
---> 50                                                compile=True)
     51         print("model_with_function loaded with the following metrics:")
     52         print(new_model.metrics)

/path/to/tensorflow_core/python/keras/saving/save.py in load_model(filepath, custom_objects, compile)
    148   if isinstance(filepath, six.string_types):
    149     loader_impl.parse_saved_model(filepath)
--> 150     return saved_model_load.load(filepath, compile)
    151 
    152   raise IOError(

/path/to/tensorflow_core/python/keras/saving/saved_model/load.py in load(path, compile)
     91     if model._training_config is not None:  # pylint: disable=protected-access
     92       model.compile(**saving_utils.compile_args_from_training_config(
---> 93           model._training_config))  # pylint: disable=protected-access
     94 
     95   return model

/path/to/tensorflow_core/python/training/tracking/base.py in _method_wrapper(self, *args, **kwargs)
    455     self._self_setattr_tracking = False  # pylint: disable=protected-access
    456     try:
--> 457       result = method(self, *args, **kwargs)
    458     finally:
    459       self._self_setattr_tracking = previous_value  # pylint: disable=protected-access

/path/to/tensorflow_core/python/keras/engine/training.py in compile(self, optimizer, loss, metrics, loss_weights, sample_weight_mode, weighted_metrics, target_tensors, distribute, **kwargs)
    354     with K.get_graph().as_default():
    355       # Save all metric attributes per output of the model.
--> 356       self._cache_output_metric_attributes(metrics, weighted_metrics)
    357 
    358       # Set metric attributes on model.

/path/to/tensorflow_core/python/keras/engine/training.py in _cache_output_metric_attributes(self, metrics, weighted_metrics)
   1899         output_shapes.append(output.shape.as_list())
   1900     self._per_output_metrics = training_utils.collect_per_output_metric_info(
-> 1901         metrics, self.output_names, output_shapes, self.loss_functions)
   1902     self._per_output_weighted_metrics = (
   1903         training_utils.collect_per_output_metric_info(

/path/to/tensorflow_core/python/keras/engine/training_utils.py in collect_per_output_metric_info(metrics, output_names, output_shapes, loss_fns, is_weighted)
    811     metrics_dict = OrderedDict()
    812     for metric in metrics:
--> 813       metric_name = get_metric_name(metric, is_weighted)
    814       metric_fn = get_metric_function(
    815           metric, output_shape=output_shapes[i], loss_fn=loss_fns[i])

/path/to/tensorflow_core/python/keras/engine/training_utils.py in get_metric_name(metric, weighted)
    985       return metric
    986 
--> 987     metric = metrics_module.get(metric)
    988     return metric.name if hasattr(metric, 'name') else metric.__name__
    989   else:

/path/to/tensorflow_core/python/keras/metrics.py in get(identifier)
   2855 def get(identifier):
   2856   if isinstance(identifier, dict):
-> 2857     return deserialize(identifier)
   2858   elif isinstance(identifier, six.string_types):
   2859     return deserialize(str(identifier))

/path/to/tensorflow_core/python/keras/metrics.py in deserialize(config, custom_objects)
   2849       module_objects=globals(),
   2850       custom_objects=custom_objects,
-> 2851       printable_module_name='metric function')
   2852 
   2853 

/path/to/tensorflow_core/python/keras/utils/generic_utils.py in deserialize_keras_object(identifier, module_objects, custom_objects, printable_module_name)
    178     config = identifier
    179     (cls, cls_config) = class_and_config_for_serialized_keras_object(
--> 180         config, module_objects, custom_objects, printable_module_name)
    181 
    182     if hasattr(cls, 'from_config'):

/path/to/tensorflow_core/python/keras/utils/generic_utils.py in class_and_config_for_serialized_keras_object(config, module_objects, custom_objects, printable_module_name)
    163     cls = module_objects.get(class_name)
    164     if cls is None:
--> 165       raise ValueError('Unknown ' + printable_module_name + ': ' + class_name)
    166   return (cls, config['config'])
    167 

About this issue

  • Original URL
  • State: closed
  • Created 5 years ago
  • Reactions: 2
  • Comments: 17 (9 by maintainers)

Most upvoted comments

Hello! Was this ever solved for saving/loading custom metrics in the SavedModel format as opposed to .h5?

I had the same issue, only my error was:

Unable to restore custom object of type _tf_keras_metric currently. Please make sure that the layer implements get_config and from_config when saving. In addition, please use the custom_objects arg when calling load_model().

I had subclassed MeanMetricWrapper, so it couldn’t possibly have been a lack of implementing get_config and from_config, and I had already put together the custom_objects dict, which contained:

{
    "mase_metric": MeanAbsoluteScaledErrorMetric,
    "smape_metric": SymmetricMeanAbsolutePercentErrorMetric,
}

Everything was referenced correctly in the main script (model would run manually and through hyperparameter searches), but I kept getting this error whenever I tried loading the saved TF model. I then switched to saving/loading an H5 model instead, and got an error stating that MeanAbsoluteScaledErrorMetric wasn’t included in custom_objects. So I updated it to:

{
    "MeanAbsoluteScaledErrorMetric": MeanAbsoluteScaledErrorMetric,
    "SymmetricMeanAbsolutePercentErrorMetric": SymmetricMeanAbsolutePercentErrorMetric,
    "mase_metric": MeanAbsoluteScaledErrorMetric,
    "smape_metric": SymmetricMeanAbsolutePercentErrorMetric,
}

…and then it worked. I then switched back to the TF model and it kept working. So in the end, I suppose somewhere in the loader it’s not respecting the key/value relationship in custom_objects and only looking for the class name in the keys.
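
For what it’s worth, the same mapping can also be exposed through a custom object scope instead of (or in addition to) the custom_objects argument; a rough sketch of what I mean, keyed by class name since that seems to be what the loader looks up (I only actually verified the custom_objects variant above):

# Sketch only: expose the custom metric classes via a scope during loading.
# "path/to/saved_model" and the metric class names refer to my own code above.
import tensorflow as tf

with tf.keras.utils.CustomObjectScope({
        "MeanAbsoluteScaledErrorMetric": MeanAbsoluteScaledErrorMetric,
        "mase_metric": MeanAbsoluteScaledErrorMetric,
        "SymmetricMeanAbsolutePercentErrorMetric": SymmetricMeanAbsolutePercentErrorMetric,
        "smape_metric": SymmetricMeanAbsolutePercentErrorMetric,
}):
    model = tf.keras.models.load_model("path/to/saved_model")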

I have seen your gist, and after installing tf-nightly I have been able to replicate it on my laptop, thank you.

The only small difference I see is that locally I have an additional warning: WARNING: Logging before flag parsing goes to stderr.