tensorflow: ModelCheckpoint callback can't save entire subclassed models

The Keras ModelCheckpoint callback won't save a subclassed model in SavedModel format, even with save_weights_only=False. It only ever saves the weight checkpoints.

import tensorflow as tf
import numpy as np

class TestModel(tf.keras.Model):

    def __init__(self):
        super().__init__()
        self.dense1 = tf.keras.layers.Dense(64, activation='relu')
        self.dense2 = tf.keras.layers.Dense(64, activation='relu')
        self.dense3 = tf.keras.layers.Dense(10)

    def call(self, inputs):
        x = self.dense1(inputs)
        x = self.dense2(x)
        return self.dense3(x)


model = TestModel()
model.compile(optimizer=tf.keras.optimizers.Adam(0.01),
              loss='mse',      
              metrics=['mae'])  

data = np.random.random((1000, 32))
labels = np.random.random((1000, 10))

model.fit(
    data, labels, epochs=10, batch_size=32,
    callbacks=[tf.keras.callbacks.ModelCheckpoint('model_{epoch:02d}.ckpt',
                                                  save_weights_only=False)])

It seems as if the callback forces weights-only saving here. That behaviour makes sense when the target is the HDF5 (.h5) format, but it doesn't make sense for the SavedModel format.
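In the meantime, a plain custom callback that calls model.save() directly can serve as a workaround, since subclassed models can be saved to SavedModel format from TF 2.2 on. This is a minimal sketch for TF 2.x; the FullModelCheckpoint name and the filepath template are my own, not part of the Keras API:

```python
import tensorflow as tf


class FullModelCheckpoint(tf.keras.callbacks.Callback):
    """Save the entire model in SavedModel format at the end of each epoch.

    `filepath_template` is any format string with an `epoch` field,
    e.g. 'model_{epoch:02d}'. A path without an .h5 suffix makes
    model.save() write SavedModel format.
    """

    def __init__(self, filepath_template):
        super().__init__()
        self.filepath_template = filepath_template

    def on_epoch_end(self, epoch, logs=None):
        # self.model is set by Keras during fit(); by the time the first
        # epoch ends, the subclassed model has been called and can be saved.
        self.model.save(self.filepath_template.format(epoch=epoch + 1))
```

It can then be passed in place of ModelCheckpoint, e.g. callbacks=[FullModelCheckpoint('model_{epoch:02d}')], at the cost of losing ModelCheckpoint extras such as save_best_only.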

About this issue

  • Original URL
  • State: closed
  • Created 4 years ago
  • Reactions: 5
  • Comments: 21 (5 by maintainers)

Most upvoted comments

@gowthamkpr Any update on this? I understand why this was the intended behavior back when subclassed models couldn’t be saved. But now that this functionality has been added in tensorflow 2.2.0, it makes sense to update the callback so that there is feature parity between subclassed models and functional api models.

I'm facing the same issue as well. A subclassed model is the only option for the Keras CRF layer, and as a result of this issue we can't save the best model after early stopping.

Upvote here. I would love to see this fixed.

Yes, let’s have an update on this ASAP, please.

Thanks.

Okay, actually it turns out that _is_graph_network is not about Eager execution, but about the distinction between the functional and subclass APIs.

In functional API, you would define your example network as:

dense1 = tf.keras.layers.Dense(64, activation='relu')
dense2 = tf.keras.layers.Dense(64, activation='relu')
dense3 = tf.keras.layers.Dense(10)
inputs = tf.keras.Input((32,), dtype='float32')
output = dense3(dense2(dense1(inputs)))
model = tf.keras.Model(inputs, output)

This builds the computation graph at instantiation, because the specification of the inputs (and hence of every node) is fully defined up front.
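The distinction can be observed directly; note that _is_graph_network is an internal, non-public attribute of tf.keras models in TF 2.x, so this is an illustration rather than supported API:

```python
import tensorflow as tf

# Functional model: the graph is built at construction time.
inputs = tf.keras.Input((32,), dtype='float32')
functional = tf.keras.Model(inputs, tf.keras.layers.Dense(10)(inputs))


# Subclassed model: no graph exists until the model is first called.
class Subclassed(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.dense = tf.keras.layers.Dense(10)

    def call(self, x):
        return self.dense(x)


subclassed = Subclassed()

print(functional._is_graph_network)   # True
print(subclassed._is_graph_network)   # False
```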

In the subclass API, however, the instantiation procedure differs, which results in a Model with somewhat reduced capabilities. To be fair, the point of this is to allow defining Model subclasses that implement alternative behaviours (e.g. custom train-step procedures) yet remain general enough to be used with the functional API.

Now in your example, I would advise writing your model architecture as a tf.keras.layers.Layer subclass, and then using the functional API to instantiate a Model out of it:

class TestModel(tf.keras.layers.Layer):

    def __init__(self):
        super().__init__()
        self.dense1 = tf.keras.layers.Dense(64, activation='relu')
        self.dense2 = tf.keras.layers.Dense(64, activation='relu')
        self.dense3 = tf.keras.layers.Dense(10)

    def call(self, inputs):
        x = self.dense1(inputs)
        x = self.dense2(x)
        return self.dense3(x)

inputs = tf.keras.Input((32,), dtype='float32')
network = TestModel()
model = tf.keras.Model(inputs, network(inputs))
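With the wrapper in place, the original ModelCheckpoint call works unchanged, because the resulting model is a graph network. A runnable sketch of the full workaround (filenames are arbitrary; a filepath without an .h5 suffix selects SavedModel format):

```python
import numpy as np
import tensorflow as tf


class TestModel(tf.keras.layers.Layer):
    def __init__(self):
        super().__init__()
        self.dense1 = tf.keras.layers.Dense(64, activation='relu')
        self.dense2 = tf.keras.layers.Dense(64, activation='relu')
        self.dense3 = tf.keras.layers.Dense(10)

    def call(self, inputs):
        return self.dense3(self.dense2(self.dense1(inputs)))


# Wrap the subclassed layer in a functional Model.
inputs = tf.keras.Input((32,), dtype='float32')
model = tf.keras.Model(inputs, TestModel()(inputs))
model.compile(optimizer=tf.keras.optimizers.Adam(0.01),
              loss='mse', metrics=['mae'])

data = np.random.random((1000, 32))
labels = np.random.random((1000, 10))

# Because the model is now a graph network, ModelCheckpoint can save
# the entire model in SavedModel format, one directory per epoch.
model.fit(data, labels, epochs=2, batch_size=32, verbose=0,
          callbacks=[tf.keras.callbacks.ModelCheckpoint(
              'full_model_{epoch:02d}', save_weights_only=False)])
```

The saved directories can be reloaded with tf.keras.models.load_model, which is exactly what fails for the bare subclassed model above.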