tensorflow: ModelCheckpoint behavior doesn't change when changing save_weights_only from True to False for custom tf.keras.Model

System information

  • Have I written custom code (as opposed to using a stock example script provided in TensorFlow): Yes
  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Google Colab
  • Python version: 3.6.9
  • Tensorflow version: 2.3.0 (‘v2.3.0-0-gb36436b087’)

Describe the current behavior

I am migrating from TF1 to TF2 and I am trying to understand best practices when it comes to model saving/checkpointing during training.

I would like to checkpoint the model training so as to be able to resume training later.

Quoting from the documentation (https://www.tensorflow.org/guide/keras/save_and_serialize#saving_loading_only_the_models_weights_values)

You can choose to only save & load a model's weights. This can be useful if:
* You only need the model for inference: in this case you won't need to restart training, so you don't need the compilation information or optimizer state.
* You are doing transfer learning: in this case you will be training a new model reusing the state of a prior model, so you don't need the compilation information of the prior model.

So, given my use case of resuming training, it seems to me that I would want to checkpoint the full model and not just the weights, correct? More specifically, I would like the following to be saved at each epoch:

* The model's architecture/config
* The model's weight values (which were learned during training)
* The model's compilation information (if compile() was called)
* The optimizer and its state, if any (this enables you to restart training where you left)

The following code snippet shown in the documentation (https://www.tensorflow.org/tutorials/keras/save_and_load#checkpoint_callback_usage) only saves the weights:

checkpoint_path = "training_1/cp.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)

# Create a callback that saves the model's weights
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,
                                                 save_weights_only=True,
                                                 verbose=1)

# Train the model with the new callback
model.fit(train_images, 
          train_labels,  
          epochs=10,
          validation_data=(test_images,test_labels),
          callbacks=[cp_callback])  # Pass callback to training

So I made a single change to the above code, changing save_weights_only from True to False:

checkpoint_path = "training_1/cp.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)

# Create a callback that is asked to save the full model, not just the weights
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,
                                                 save_weights_only=False,   # <- this is the change I made 
                                                 verbose=1)

# Train the model with the new callback
model.fit(train_images, 
          train_labels,  
          epochs=10,
          validation_data=(test_images,test_labels),
          callbacks=[cp_callback])  # Pass callback to training

Now I see a cp.ckpt file instead of the .data and .index files; all good so far.

However, the example in the documentation uses a tf.keras.Sequential model, and I would like to subclass tf.keras.Model instead. So I replicated the example’s model using a very simple tf.keras.Model subclass:

class SimpleModel(tf.keras.Model):
  def __init__(self, **kwargs):
    super().__init__(**kwargs)
    self.layer_1 = tf.keras.layers.Dense(512, activation='relu', input_shape=(784,))
    self.layer_2 = tf.keras.layers.Dropout(0.2)
    self.layer_3 = tf.keras.layers.Dense(10)
  
  def call(self, inputs):
    x = self.layer_1(inputs)
    x = self.layer_2(x)
    return self.layer_3(x)

Now the callback no longer creates cp.ckpt files. Instead I see .data and .index files, which indicates that only the weights are being saved.

It would be great if someone can explain how to go about checkpointing the full model in a SavedModel (.tf) format for the purpose of resuming training specifically for models implemented by subclassing tf.keras.Model.
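One possible direction, sketched here under the assumption that tf.train.Checkpoint and tf.train.CheckpointManager are an acceptable substitute for ModelCheckpoint, is to track the subclassed model and its optimizer together and write a checkpoint at the end of each epoch. This does not produce a SavedModel; the architecture still has to be rebuilt from code before restoring. The names ManagerCheckpoint and make_training_objects are hypothetical, the model is a trimmed stand-in for SimpleModel, and random arrays replace the MNIST data.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


class SimpleModel(tf.keras.Model):
    """Trimmed stand-in for the subclassed model in the question."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.layer_1 = tf.keras.layers.Dense(16, activation="relu")
        self.layer_2 = tf.keras.layers.Dense(10)

    def call(self, inputs):
        return self.layer_2(self.layer_1(inputs))


class ManagerCheckpoint(tf.keras.callbacks.Callback):
    """Hypothetical callback: writes a tf.train.Checkpoint after each epoch."""

    def __init__(self, manager):
        super().__init__()
        self.manager = manager

    def on_epoch_end(self, epoch, logs=None):
        self.manager.save()


def make_training_objects(ckpt_dir):
    """Hypothetical helper: builds the model, optimizer, and checkpoint manager."""
    model = SimpleModel()
    optimizer = tf.keras.optimizers.Adam()
    model.compile(
        optimizer=optimizer,
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    )
    # Track the weights and the optimizer in a single checkpoint object.
    ckpt = tf.train.Checkpoint(model=model, optimizer=optimizer)
    manager = tf.train.CheckpointManager(ckpt, ckpt_dir, max_to_keep=3)
    return model, ckpt, manager


# Random stand-in data with the same shapes as flattened MNIST.
x = np.random.rand(64, 784).astype("float32")
y = np.random.randint(0, 10, size=(64,))

ckpt_dir = os.path.join(tempfile.mkdtemp(), "training_1")

# First run: train and checkpoint after every epoch.
model, ckpt, manager = make_training_objects(ckpt_dir)
model.fit(x, y, epochs=2, verbose=0, callbacks=[ManagerCheckpoint(manager)])

# Later run: rebuild the objects from code, create the weights, then restore.
resumed_model, resumed_ckpt, resumed_manager = make_training_objects(ckpt_dir)
resumed_model(x[:1])  # Call once so the layer variables exist.
resumed_ckpt.restore(resumed_manager.latest_checkpoint)

# The restored model should reproduce the trained model's outputs.
outputs_match = np.allclose(
    model(x[:4]).numpy(), resumed_model(x[:4]).numpy(), atol=1e-6
)
```

From here, training could continue by calling resumed_model.fit with a ManagerCheckpoint built on resumed_manager. Whether the optimizer's slot variables are restored at that point depends on the TensorFlow/Keras version, so that part is worth verifying in your own environment.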

About this issue

  • State: closed
  • Created 4 years ago
  • Reactions: 3
  • Comments: 17 (5 by maintainers)

Most upvoted comments

@chengjianhong - my current workaround is to implement my own ModelCheckpoint that overrides set_model:

class CustomCheckpoint(tf.keras.callbacks.ModelCheckpoint):
    def set_model(self, model):
        # Only store the model; the parent class's set_model logic is skipped.
        self.model = model

@Saduf2019 - please find the Colab gist to reproduce this issue (I made sure I am using tf-nightly): https://colab.research.google.com/gist/marwan116/ab3ac395d0dfe111eb0ec0853bb8e3f8/untitled1.ipynb

@marwan116 Have you solved the problem yet? I also use ‘save_best_only=True’ to save the model, but when I use ‘tf.keras.models.load_model’ to load the model file, I get the error: “ValueError: No model found in config file.”
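If the file on disk holds only weights (an assumption about the case above, not something confirmed in this thread), one way to use it is to rebuild the model from code and call load_weights instead of load_model. The sketch below uses a trimmed stand-in for SimpleModel and a hypothetical cp.weights.h5 file name, chosen because newer Keras versions require that suffix; with TF 2.3 the cp.ckpt prefix from the issue can be passed to load_weights in the same way.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


class SimpleModel(tf.keras.Model):
    """Trimmed stand-in for the subclassed model in the question."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.layer_1 = tf.keras.layers.Dense(16, activation="relu")
        self.layer_2 = tf.keras.layers.Dense(10)

    def call(self, inputs):
        return self.layer_2(self.layer_1(inputs))


# Random stand-in input with the same shape as flattened MNIST.
x = np.random.rand(8, 784).astype("float32")

# Stand-in for the training run that produced a weights-only file.
trained = SimpleModel()
trained(x)  # Call once so the layer variables exist.
weights_path = os.path.join(tempfile.mkdtemp(), "cp.weights.h5")
trained.save_weights(weights_path)

# Rebuild the architecture from code, create its variables, then load.
restored = SimpleModel()
restored(x)
restored.load_weights(weights_path)
```

This recovers the weights only. Compilation information and optimizer state are not part of such a file, so the model has to be compiled again before further training.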