tensorflow: Cannot resume training using model.save and load_model()

System Information

  • Have I written custom code (as opposed to using a stock example script provided in TensorFlow): Yes

  • OS Platform and Distribution: CentOS Linux 7.6.1810

  • TensorFlow installed from: binary (pip)

  • TensorFlow version: 2.2.0-rc3

  • Python version: 3.6.4

  • CUDA/cuDNN version: CUDA 10.1 / cuDNN 7.6.0

  • GPU model and memory: 2 x TitanX - 12GB

Describe the current behavior

I am using TensorFlow 2.2.0 on a multi-GPU system. Because I need to train large networks for several days, I save the model weights together with the optimizer state using model.save(). When I reload the model with tf.keras.models.load_model(), the loss spikes sharply in TensorBoard and the accuracy drops suddenly. Although the loss recovers within the epoch, this is not the intended behavior of saving the training state with model.save().

Describe the expected behavior

The API should be able to save a model to an ‘.h5’ file, load it again, and resume training from exactly the same point.
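What "exactly the same point" requires can be illustrated without TensorFlow: a stateful optimizer only reproduces an uninterrupted run if its internal state (here, a momentum buffer) is restored together with the weights. This is a toy pure-Python sketch; `MomentumSGD`, `grad` and `train` are hypothetical names, not TensorFlow APIs.

```python
from dataclasses import dataclass


@dataclass
class MomentumSGD:
    """Hypothetical toy optimizer: SGD with a momentum buffer as its state."""
    lr: float = 0.1
    momentum: float = 0.9
    velocity: float = 0.0  # internal state that a checkpoint must capture

    def update(self, w: float, g: float) -> float:
        self.velocity = self.momentum * self.velocity + g
        return w - self.lr * self.velocity


def grad(w: float) -> float:
    # Gradient of the toy loss 0.5 * (w - 3)^2.
    return w - 3.0


def train(w: float, opt: MomentumSGD, steps: int) -> float:
    for _ in range(steps):
        w = opt.update(w, grad(w))
    return w


# Uninterrupted run: 10 steps.
w_full = train(0.0, MomentumSGD(), 10)

# Interrupted run: 5 steps, "checkpoint", then 5 more steps.
opt = MomentumSGD()
w_ckpt = train(0.0, opt, 5)
saved_velocity = opt.velocity

# (a) Resume with the weights AND the optimizer state.
w_with_state = train(w_ckpt, MomentumSGD(velocity=saved_velocity), 5)

# (b) Resume with the weights only: the optimizer state is silently reset.
w_weights_only = train(w_ckpt, MomentumSGD(), 5)

print(w_full == w_with_state)    # True: identical trajectory
print(w_full == w_weights_only)  # False: trajectory diverges after the restart
```

Only variant (a) matches the uninterrupted run; variant (b) is what a reload that drops the optimizer state amounts to.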

Standalone code to reproduce the issue

The code below is a minimal reproducible example. It was tested on multi-GPU systems with 8 GPUs. A restart of the script is simulated by deleting the current model and distribution strategy and re-initializing them, mimicking a stop and restart of the training process.

import os
import glob
import numpy as np
import tensorflow as tf
print(tf.__version__)

gpus = tf.config.experimental.list_logical_devices('GPU')
print(gpus)

RESULT_DIR = os.path.join(os.getcwd(), 'Test', 'Results')
CHECKPOINT_FREQUENCY = 16
LOG_EVERY = 1

BATCH_SIZE_PER_GPU = 16
NUM_GPUS = len(gpus)
GLOBAL_BATCH_SIZE = BATCH_SIZE_PER_GPU * NUM_GPUS

def get_model():
    
    model = tf.keras.Sequential([
        tf.keras.layers.Conv2D(filters=32, strides=1, kernel_size=(4,4), input_shape=(28,28,1)),
        tf.keras.layers.Activation('relu'),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(10)
    ])
    
    return model

class SparseCategoricalLoss(tf.keras.losses.Loss):
    
    def __init__(self, num_classes, name='SparseCategoricalLoss', from_logits=False, loss_weight=1.0, *args, **kwargs):
        
        super().__init__(*args, **kwargs)
        self.num_classes = num_classes
        self.name = name
        self.from_logits=from_logits
        self.loss_weight = loss_weight
        
    def loss_fn(self, y_true, y_pred):
        label = y_true[:,0:self.num_classes]
        logit = y_pred[:,0:self.num_classes]
        loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=self.from_logits,
                                                             name=self.name,
                                                             reduction=tf.keras.losses.Reduction.NONE)(label, logit)
        loss *= self.loss_weight
        return loss
    
    
    def call(self, y_true, y_pred):
        total_loss = self.loss_fn(y_true, y_pred)
        return total_loss

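    def get_config(self):
        # Include every constructor argument; otherwise from_config() (used by
        # load_model) falls back to the defaults, e.g. from_logits=False.
        config = super().get_config().copy()
        config.update({
            'num_classes' : self.num_classes,
            'name' : self.name,
            'from_logits' : self.from_logits,
            'loss_weight' : self.loss_weight
        })
        return config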

loss = SparseCategoricalLoss(num_classes=10,
                             from_logits=True,
                             name='categorical_loss')

strategy = tf.distribute.MirroredStrategy()

with strategy.scope():
    
    model = get_model()
    
    optimizer = tf.keras.optimizers.RMSprop(
                                            learning_rate=0.001,
                                            epsilon=1.0,
                                            momentum=0.9,
                                            rho=0.9
                                           )
    
    model.compile(optimizer=optimizer, loss=loss, metrics=['acc'])

(X_train, Y_train), (X_test, Y_test) = tf.keras.datasets.mnist.load_data()
X_train = np.expand_dims(X_train, 3)
X_test = np.expand_dims(X_test, 3)

class LoggingCallback(tf.keras.callbacks.Callback):

    def __init__(self, result_dir, log_every, initial_step=0, checkpoint_frequency=None, **kwargs):
        
        super().__init__(**kwargs)
        
        # Create result directory
        self.result_dir = result_dir
        if not os.path.exists(result_dir):
            os.makedirs(result_dir)
        
        # create checkpoint directory
        checkpoint_dir = os.path.join(self.result_dir, 'checkpoint')
        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)
        
        # create tensorboard directory
        tensorboard_dir = os.path.join(self.result_dir, 'tensorboard')
        if not os.path.exists(tensorboard_dir):
            os.makedirs(tensorboard_dir)
        
        self.log_every = log_every
        self.checkpoint_frequency = checkpoint_frequency
        self.train_writer = tf.summary.create_file_writer( os.path.join(tensorboard_dir, 'train') )
        self.step = initial_step
        
        
    # Write metrics to TensorBoard    
    def write_metrics_tensorboard(self, logs):
        with self.train_writer.as_default():
            for name, value in logs.items():
                if name in ['batch', 'size']:
                    continue
                tf.summary.scalar(name, value, step=self.step)
                
                
    def on_batch_end(self, batch, logs=None):
        
        self.step += 1
        
        # Write metrics to tensorboard
        if self.step % self.log_every == 0:
            self.write_metrics_tensorboard(logs)
            
        # Save model checkpoint (weights + optimizer state)
        if self.checkpoint_frequency and self.step % self.checkpoint_frequency == 0:
            name = 'model_step_%d.h5' % self.step
            path = os.path.join(self.result_dir, 'checkpoint', name)
            self.model.save( path )

callbacks = LoggingCallback(result_dir=RESULT_DIR, log_every=LOG_EVERY, checkpoint_frequency=CHECKPOINT_FREQUENCY)

model.fit(
          x = X_train, 
          y = Y_train, 
          batch_size=GLOBAL_BATCH_SIZE,
          epochs=7,
          validation_data = (X_test, Y_test),
          callbacks=callbacks,
          verbose=1 
         )

del model
del strategy

previous_checkpoints = glob.glob(os.path.join(RESULT_DIR, 'checkpoint', '*'))
previous_checkpoints.sort(key=lambda x : int(os.path.basename(x).split('_')[2].replace('.h5', '')) )
latest_checkpoint = previous_checkpoints[-1]
print('Found Latest Checkpoint : %s' % latest_checkpoint)
    
initial_step = int(os.path.basename(latest_checkpoint).split('_')[2].replace('.h5', ''))
print('Resuming training from step %d' % initial_step)
    
new_callback = LoggingCallback(result_dir=RESULT_DIR, log_every=LOG_EVERY, initial_step=initial_step, checkpoint_frequency=CHECKPOINT_FREQUENCY)

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    model = tf.keras.models.load_model( latest_checkpoint, custom_objects={'SparseCategoricalLoss':SparseCategoricalLoss} )

model.fit(
          x = X_train, 
          y = Y_train, 
          batch_size=GLOBAL_BATCH_SIZE,
          epochs=10,
          validation_data = (X_test, Y_test),
          callbacks=new_callback,
          verbose=1 
         )
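To check whether a reload actually restored the optimizer, one can compare the optimizer variables of the saved model with those of the reloaded one. A minimal sketch, assuming the TF 2.x tf.keras optimizers used in this thread; `optimizer_state_matches` and `_optimizer_values` are hypothetical helper names. It reads `optimizer.variables`, which is a method on the `OptimizerV2` optimizers of TF 2.2–2.10 and a property on newer ones, so both cases are handled.

```python
import numpy as np
import tensorflow as tf


def _optimizer_values(model):
    """Return the current values of a compiled model's optimizer variables."""
    variables = model.optimizer.variables
    if callable(variables):  # OptimizerV2 (TF 2.2-2.10): variables() is a method
        variables = variables()
    return [np.asarray(v.numpy()) for v in variables]


def optimizer_state_matches(model_a, model_b):
    """Hypothetical helper: True if both models' optimizers hold the same state.

    A length mismatch usually means that slot variables (e.g. the RMSprop
    'rms' and 'momentum' slots) were never created or restored in one model.
    """
    a, b = _optimizer_values(model_a), _optimizer_values(model_b)
    if len(a) != len(b):
        return False
    return all(x.shape == y.shape and np.allclose(x, y) for x, y in zip(a, b))
```

In the reproduction script, this could be called with the model from the first run (keeping a reference instead of `del model`) and the model returned by `load_model` inside the new strategy scope.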

Here is a link to a Colab notebook showing the output: https://colab.research.google.com/gist/suraj-maniyar/1a305d7249baee4393147cb479ea2933/restart_training.ipynb

Other info / logs

The TensorBoard entry looks like this: [TensorBoard screenshot]

This was a toy example using MNIST. After about 26k steps, when the training was restarted, the loss spiked, suggesting that the last saved checkpoint did not correctly capture the training state. I am training an InceptionResNet network for several days, and the loss spike when I restart training (shown below) is very concerning.

[TensorBoard screenshot: tensorboard_inception]

About this issue

  • State: closed
  • Created 4 years ago
  • Comments: 31 (11 by maintainers)

Most upvoted comments

@w4nderlust I still see the issue with the SGD optimizer (gist). I had tried something similar to what you mentioned in your issue, i.e. saving and loading the optimizer state separately as a pickle object, but I still saw the spike in the loss. I wonder whether tf.distribute.Strategy uses different logic to save and reload the model, causing the optimizer state to reset.

Still happens in 2.6.0. I read elsewhere that this happens with distribution strategies: the optimizer weights are not saved in the SavedModel, so they are not loaded when restarting from it. Nowhere do the TF docs warn about this. I don't mind if that is the intended behavior, but at least update the docs and give a workaround; this has been an issue for 2+ years.

My workaround steps are:

  1. Save two checkpoints during training: an .h5 model checkpoint and a weights-only checkpoint.
  2. When resuming, load the model from the .h5 file, then call model.load_weights() with the weights-only checkpoint.
  3. After training, call model.save("path", include_optimizer=False) to create a SavedModel for inference.

You can print model.optimizer.weights after this. If you load a SavedModel saved during distributed training, you cannot print model.optimizer.weights because they don't exist. And if you try to load a weights checkpoint into a loaded SavedModel, it will say that the optimizer weights are unused.
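The workaround steps above can be sketched as follows. This is a minimal sketch under stated assumptions: TF 2.x with the tf.keras (Keras 2) saving API used in this thread; the tiny model, the random data and the file names are hypothetical placeholders. The key detail is that `save_weights()` with a path that does not end in `.h5` writes a TensorFlow-format checkpoint, which (unlike an HDF5 weights file) also tracks the optimizer's variables.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


def build_model():
    # Hypothetical stand-in for the real network.
    model = tf.keras.Sequential([
        tf.keras.layers.Dense(8, activation="relu", input_shape=(4,)),
        tf.keras.layers.Dense(1),
    ])
    model.compile(optimizer=tf.keras.optimizers.RMSprop(momentum=0.9), loss="mse")
    return model


ckpt_dir = tempfile.mkdtemp()
h5_path = os.path.join(ckpt_dir, "model.h5")             # step 1a: full .h5 model
weights_prefix = os.path.join(ckpt_dir, "weights_ckpt")  # step 1b: TF-format checkpoint (no .h5 suffix)

x = np.random.rand(64, 4).astype("float32")
y = np.random.rand(64, 1).astype("float32")

model = build_model()
model.fit(x, y, batch_size=16, epochs=1, verbose=0)

# Step 1: write both checkpoints.
model.save(h5_path)
model.save_weights(weights_prefix)

# Step 2: on restart, rebuild from the .h5 file, then restore from the TF-format checkpoint.
resumed = tf.keras.models.load_model(h5_path)
resumed.load_weights(weights_prefix)

# Continue training with the restored model.
resumed.fit(x, y, batch_size=16, epochs=1, verbose=0)

# Step 3: export an inference-only SavedModel without optimizer state.
export_dir = os.path.join(ckpt_dir, "inference_model")
resumed.save(export_dir, include_optimizer=False)
```

Keeping the .h5 file for the architecture and compile configuration while relying on the TF-format checkpoint for the variables mirrors the division of labour described in the comment.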

Same issue here:

  1. Trained a model in TF==2.2.0 and saved it in SavedModel format, as well as JSON and .h5.
  2. In a new session, loaded the model and continued/resumed training. Out of the box, the accuracy starts from 0 and the loss is high.

I had the same problem as suraj-maniyar with a custom training loop and a custom loss. I'm using a stateful optimizer as well, but without any distribution strategy. I found a satisfying solution here: https://stackoverflow.com/questions/49503748/save-and-load-model-optimizer-state.

The problem was restoring the optimizer state, which turned out to be possible without compiling the model. This solution works with Adam and could probably be adapted to other optimizers.

Hope this helps.

Same issue here. I tried TF 2.0 and 2.2 with the same result. It is quite a basic scenario:

  1. Train the model for x epochs.
  2. Save the model as JSON and the weights separately.
  3. Restart the process to resume training:
    • load the model from JSON
    • load the weights
    • call model.fit()

I expected the ‘loss’ to be close to the last best saved value (monitor = ‘loss’). Instead, the model starts training as if the weights were not loaded at all. The distribution strategy does not affect the result.

Note: when all of the above steps are run sequentially in one process (without the restart in step 3), the model continues training as expected.

Thanks for confirming @suraj-maniyar. We will ask the folks who are expert in Keras to take a look.

@w4nderlust I still see the problem even after removing batch normalization (gist).