tensorflow: Cannot resume training using model.save and load_model()
System Information
- Have I written custom code (as opposed to using a stock example script provided in TensorFlow): Yes
- OS Platform and Distribution: CentOS Linux 7.6.1810
- TensorFlow installed from: binary (pip)
- TensorFlow version: 2.2.0-rc3
- Python version: 3.6.4
- CUDA/cuDNN version: CUDA 10.1 / cuDNN 7.6.0
- GPU model and memory: 2 x TitanX, 12 GB each
Describe the current behavior
I am using TensorFlow 2.2.0 on a multi-GPU system. Because I need to train large networks for several days, I save the model weights together with the optimizer state using model.save(). When I reload the model using tf.keras.models.load_model(), the loss spikes sharply on TensorBoard and the accuracy also drops suddenly. Although the loss recovers within the epoch, this does not match the intended behavior of saving the full training state with model.save().
Describe the expected behavior
The API should be able to save and then resume training from the very same point after loading the model from the '.h5' file.
Standalone code to reproduce the issue
This is a minimal reproducible example. It was tested on a multi-GPU system with 8 GPUs. The re-run of the script is simulated by deleting the current model and distribution strategy and re-initializing them, which mimics stopping and restarting the training process.
import os
import glob
import numpy as np
import tensorflow as tf
print(tf.__version__)
gpus = tf.config.experimental.list_logical_devices('GPU')
print(gpus)
RESULT_DIR = os.path.join(os.getcwd(), 'Test', 'Results')
CHECKPOINT_FREQUENCY = 16
LOG_EVERY = 1
BATCH_SIZE_PER_GPU = 16
NUM_GPUS = len(gpus)
GLOBAL_BATCH_SIZE = BATCH_SIZE_PER_GPU * NUM_GPUS
def get_model():
    model = tf.keras.Sequential([
        tf.keras.layers.Conv2D(filters=32, strides=1, kernel_size=(4, 4), input_shape=(28, 28, 1)),
        tf.keras.layers.Activation('relu'),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(10)
    ])
    return model
class SparseCategoricalLoss(tf.keras.losses.Loss):
    def __init__(self, num_classes, name='SparseCategoricalLoss', from_logits=False, loss_weight=1.0, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.num_classes = num_classes
        self.name = name
        self.from_logits = from_logits
        self.loss_weight = loss_weight

    def loss_fn(self, y_true, y_pred):
        label = y_true[:, 0:self.num_classes]
        logit = y_pred[:, 0:self.num_classes]
        loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=self.from_logits,
                                                             name=self.name,
                                                             reduction=tf.keras.losses.Reduction.NONE)(label, logit)
        loss *= self.loss_weight
        return loss

    def call(self, y_true, y_pred):
        total_loss = self.loss_fn(y_true, y_pred)
        return total_loss

    def get_config(self):
        config = super().get_config().copy()
        config.update({
            'num_classes': self.num_classes,
            'name': self.name,
            'loss_weight': self.loss_weight
        })
        return config
loss = SparseCategoricalLoss(num_classes=10,
                             from_logits=True,
                             name='categorical_loss')

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    model = get_model()
    optimizer = tf.keras.optimizers.RMSprop(
        learning_rate=0.001,
        epsilon=1.0,
        momentum=0.9,
        rho=0.9
    )
    model.compile(optimizer=optimizer, loss=loss, metrics=['acc'])
(X_train, Y_train), (X_test, Y_test) = tf.keras.datasets.mnist.load_data()
X_train = np.expand_dims(X_train, 3)
X_test = np.expand_dims(X_test, 3)
class LoggingCallback(tf.keras.callbacks.Callback):
    def __init__(self, result_dir, log_every, initial_step=0, checkpoint_frequency=None, **kwargs):
        super().__init__(**kwargs)
        # Create result directory
        self.result_dir = result_dir
        if not os.path.exists(result_dir):
            os.makedirs(result_dir)
        # Create checkpoint directory
        checkpoint_dir = os.path.join(self.result_dir, 'checkpoint')
        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)
        # Create TensorBoard directory
        tensorboard_dir = os.path.join(self.result_dir, 'tensorboard')
        if not os.path.exists(tensorboard_dir):
            os.makedirs(tensorboard_dir)
        self.log_every = log_every
        self.checkpoint_frequency = checkpoint_frequency
        self.train_writer = tf.summary.create_file_writer(os.path.join(tensorboard_dir, 'train'))
        self.step = initial_step

    # Write metrics to TensorBoard
    def write_metrics_tensorboard(self, logs):
        with self.train_writer.as_default():
            for name, value in logs.items():
                if name in ['batch', 'size']:
                    continue
                tf.summary.scalar(name, value, step=self.step)

    def on_batch_end(self, batch, logs=None):
        self.step += 1
        # Write metrics to TensorBoard
        if self.step % self.log_every == 0:
            self.write_metrics_tensorboard(logs)
        # Save model checkpoint (weights + optimizer state)
        if self.checkpoint_frequency and self.step % self.checkpoint_frequency == 0:
            name = 'model_step_%d.h5' % self.step
            path = os.path.join(self.result_dir, 'checkpoint', name)
            self.model.save(path)
callbacks = LoggingCallback(result_dir=RESULT_DIR, log_every=LOG_EVERY, checkpoint_frequency=CHECKPOINT_FREQUENCY)
model.fit(
    x=X_train,
    y=Y_train,
    batch_size=GLOBAL_BATCH_SIZE,
    epochs=7,
    validation_data=(X_test, Y_test),
    callbacks=callbacks,
    verbose=1
)
del model
del strategy
previous_checkpoints = glob.glob(os.path.join(RESULT_DIR, 'checkpoint', '*'))
previous_checkpoints.sort(key=lambda x : int(os.path.basename(x).split('_')[2].replace('.h5', '')) )
latest_checkpoint = previous_checkpoints[-1]
print('Found Latest Checkpoint : %s' % latest_checkpoint)
initial_step = int(os.path.basename(latest_checkpoint).split('_')[2].replace('.h5', ''))
print('Resuming training from step %d' % initial_step)
new_callback = LoggingCallback(result_dir=RESULT_DIR, log_every=LOG_EVERY, initial_step=initial_step, checkpoint_frequency=CHECKPOINT_FREQUENCY)
strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    model = tf.keras.models.load_model(latest_checkpoint, custom_objects={'SparseCategoricalLoss': SparseCategoricalLoss})

model.fit(
    x=X_train,
    y=Y_train,
    batch_size=GLOBAL_BATCH_SIZE,
    epochs=10,
    validation_data=(X_test, Y_test),
    callbacks=new_callback,
    verbose=1
)
Here is a link to a Colab notebook showing the output: https://colab.research.google.com/gist/suraj-maniyar/1a305d7249baee4393147cb479ea2933/restart_training.ipynb
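One way to sanity-check whether the optimizer state actually survives the save/load round trip is to compare the optimizer weights before saving and after loading. A minimal sketch (it assumes the model and SparseCategoricalLoss defined above and that at least a few training steps have already run; the file name is a placeholder):

import numpy as np
import tensorflow as tf

# Snapshot the optimizer state (iteration counter + slot variables) before saving.
weights_before = model.optimizer.get_weights()
model.save('check.h5')

# Reload and snapshot again; ideally the two lists match element-wise.
reloaded = tf.keras.models.load_model(
    'check.h5', custom_objects={'SparseCategoricalLoss': SparseCategoricalLoss})
weights_after = reloaded.optimizer.get_weights()

print('num tensors before/after:', len(weights_before), len(weights_after))
for before, after in zip(weights_before, weights_after):
    print(np.allclose(before, after))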
Other info / logs
The TensorBoard entry looks like this:
This was a toy example using MNIST. After about 26k steps, when training was restarted, the loss spiked, indicating that the last saved checkpoint did not restore the training state correctly.
I am training an InceptionResNet network over several days, and the spike in the loss when I restart training (shown below) is very concerning.
About this issue
- Original URL
- State: closed
- Created 4 years ago
- Comments: 31 (11 by maintainers)
@w4nderlust I still see the issue with the SGD optimizer (gist). I had tried something similar to what you mentioned in your issue, i.e. saving and loading the optimizer state separately as a pickle object, but I was still seeing the spike in the loss. I wonder if tf.distribute.Strategy has different logic for saving and reloading the model that causes the optimizer to reset.
Still happens in 2.6.0. I read elsewhere that this happens with distribution strategies. Optimizer weights are not saved in the SavedModel, and so are not loaded when restarting from the SavedModel. Nowhere in the TF docs does it warn about this. I don't mind if that's the case, but at least update the docs and give a workaround; this has been an issue for 2+ years.
My workaround steps are:
You can print model.optimizer.weights after this. If you try to load a SavedModel saved during distributed training, you cannot print model.optimizer.weights because they don't exist. And if you try to load a weights checkpoint into a loaded SavedModel, it will say the optimizer weights are unused.
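A minimal sketch of this kind of check (illustrative only, not the exact steps from the comment above; the path is a placeholder):

import tensorflow as tf

model = tf.keras.models.load_model('saved_model_dir')  # or a '.h5' file

# If the optimizer state was restored, this prints the iteration counter and
# the slot variables; if it was lost, the list is empty.
print(len(model.optimizer.weights))
for w in model.optimizer.weights:
    print(w.name, w.shape)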
Same issue here:
I had the same problem as suraj-maniyar with a custom training setup and a custom loss. I'm using a stateful optimizer as well, but without any distribution strategy. I found a satisfying solution here: https://stackoverflow.com/questions/49503748/save-and-load-model-optimizer-state.
The problem was resuming the optimizer state, which was possible without compiling the model. This solution works with Adam and could be adapted to other optimizers, I guess.
Hope this helps.
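For reference, a sketch of that kind of approach, based on the general idea in the linked StackOverflow answer (file names are placeholders, and the zero-gradient trick used to create the optimizer slots is an assumption, not something verified in this thread):

import pickle
import tensorflow as tf

# While training: save the weights and the optimizer state side by side.
model.save_weights('model_weights.h5')
with open('optimizer_state.pkl', 'wb') as f:
    pickle.dump(model.optimizer.get_weights(), f)

# When resuming: rebuild and compile the model, then restore both.
model = get_model()
model.compile(optimizer=tf.keras.optimizers.Adam(),
              loss='sparse_categorical_crossentropy')
model.load_weights('model_weights.h5')

# The optimizer slot variables only exist after the first apply_gradients call,
# so apply zero gradients once to create them, then overwrite them with the saved state.
zero_grads = [tf.zeros_like(w) for w in model.trainable_weights]
model.optimizer.apply_gradients(zip(zero_grads, model.trainable_weights))
with open('optimizer_state.pkl', 'rb') as f:
    model.optimizer.set_weights(pickle.load(f))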
Same issue here as well. I tried TF 2.0 and 2.2, same thing. Quite a basic scenario:
Note: when all of the above steps are run sequentially in a single process (i.e. without the restart in step #3), the model continues training as expected.
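The scenario is roughly of the following shape (an illustrative sketch, not the commenter's exact code; paths and data are placeholders):

import tensorflow as tf

# 1. Train for a few epochs.
model.fit(X_train, Y_train, epochs=3)
# 2. Save the full model (weights + optimizer state).
model.save('model.h5')

# 3. Restart the Python process.

# 4. Load the saved model and continue training.
model = tf.keras.models.load_model('model.h5')
model.fit(X_train, Y_train, epochs=3)  # loss/accuracy jump instead of continuing smoothly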
Thanks for confirming, @suraj-maniyar. We will ask the folks who are experts in Keras to take a look.
@w4nderlust I still see the problem even after removing batch normalization (gist).