tensorflow: Saver: Don't fail on restoring variables not present in a checkpoint.

This is a feature request.

Let’s consider a scenario where one trains multiple models and uses them in combination (as is done in GANs).

To simplify saving and restoring variables that are partly shared across the models (Pretraining Model, Training Model, Evaluation Model, Infer Model), one could instantiate the whole graph, containing all operations and variables, and save it.

Then, for pretraining, only the subset of graph elements that pretraining requires is used.

The cost is that the whole model (which might consist of multiple sub-models) has to be built the first time the model is run, even though only a smaller part (e.g. the Pretraining Model) is required.

Another issue arises when using different optimizers in different training stages (e.g. SGD first, then Adam). Because Adam creates additional variables, it has to be instantiated during the first training stage already; otherwise, restoring from an SGD-stage checkpoint fails once the graph is built with Adam instead of SGD.
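To make the optimizer mismatch concrete, here is a minimal pure-Python sketch of which names a graph built with Adam would look for that an SGD-stage checkpoint does not contain. The naming scheme (`<var>/Adam`, `<var>/Adam_1`, `beta1_power`, `beta2_power`) is an assumption based on how TF 1.x `AdamOptimizer` usually names its variables, and `adam_extra_names` is a hypothetical helper, not a TensorFlow function.

```python
def adam_extra_names(trainable_names):
    """Names Adam would add on top of the given variables (assumed TF 1.x naming)."""
    extras = ["beta1_power", "beta2_power"]  # global non-slot accumulators
    for name in trainable_names:
        extras.append(name + "/Adam")    # first-moment slot
        extras.append(name + "/Adam_1")  # second-moment slot
    return extras


# What a plain SGD run would have saved: SGD keeps no extra state per variable.
sgd_checkpoint = {"var"}
# What a name-based restore would ask for once the graph uses Adam.
adam_graph = {"var"} | set(adam_extra_names(["var"]))

missing = sorted(adam_graph - sgd_checkpoint)
print(missing)
```

Every name in `missing` is one that a strict restore would fail on, which is why Adam currently has to exist in the first stage as well.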

This restriction of having to build everything, even the parts that are not required, results in more complicated code. If the Saver could skip a variable that is not found in a checkpoint instead of raising an error, so that the variable can be initialized with tf.global_variables_initializer() instead, code could be structured much better.
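The requested behaviour boils down to partitioning the graph's variables into those the checkpoint can supply and those that need a fresh initializer. The sketch below shows only that partitioning step in plain Python; `split_by_checkpoint` is a hypothetical helper, and the name-to-shape dictionaries are stand-ins for what one might collect from the graph and from `tf.train.list_variables(checkpoint_path)` in a name-based `tf.train.Saver` workflow.

```python
def split_by_checkpoint(graph_shapes, checkpoint_shapes):
    """Partition graph variables into (restorable, needs_init) name lists.

    Both arguments map a variable name to its shape (a list of ints).
    A variable counts as restorable only if the checkpoint has the same
    key with the same shape.
    """
    restorable, needs_init = [], []
    for name, shape in sorted(graph_shapes.items()):
        key = name.split(":")[0]  # "var:0" -> "var", the checkpoint key
        if checkpoint_shapes.get(key) == shape:
            restorable.append(name)
        else:
            needs_init.append(name)
    return restorable, needs_init


# Stand-in data: an Adam graph restored from an SGD-stage checkpoint.
graph = {
    "var:0": [],
    "var/Adam:0": [],
    "var/Adam_1:0": [],
    "beta1_power:0": [],
    "beta2_power:0": [],
}
checkpoint = {"var": []}

restorable, needs_init = split_by_checkpoint(graph, checkpoint)
print("restore:", restorable)
print("initialize:", needs_init)
```

Under these assumptions, the first list could be passed as `var_list` to `tf.train.Saver` and the second to `tf.variables_initializer`, which approximates a quietly failing restore without changing the Saver itself.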

I have looked through all current issues regarding this problem and found two that describe a similar situation, where a QuietlyFailRestoringSaver could help: https://github.com/tensorflow/tensorflow/issues/12032 and https://github.com/tensorflow/tensorflow/issues/16781

I might consider building this if there is enough support for it. I am open to feedback.

About this issue

  • State: closed
  • Created 6 years ago
  • Comments: 30 (28 by maintainers)

Most upvoted comments

Small example:

import tensorflow as tf
import tensorflow.contrib.eager as tfe
from tensorflow.python.training import checkpointable
from tensorflow.contrib.eager.python import checkpointable_utils


class MyModel(checkpointable.Checkpointable):
  """Tracks one variable and its optimizer as checkpoint dependencies."""

  def __init__(self):
    self.optimizer = tf.train.GradientDescentOptimizer(0.1)
    self.var = tf.get_variable('var', [])

  def optimizer_to_adam(self):
    # Swap the optimizer; Adam creates additional variables of its own.
    self.optimizer = tf.train.AdamOptimizer()


checkpoint_path = 'modeldir/test/'

# Stage 1: build the model with SGD, save a checkpoint, run one step.
with tf.Graph().as_default():
  model = MyModel()
  checkpoint = tfe.Checkpoint(model=model)
  train_op = model.optimizer.minimize(model.var * model.var)
  # Print the object graph that the checkpoint will track.
  print(checkpointable_utils._serialize_object_graph(checkpoint))
  session = tf.Session()
  session.run(tf.global_variables_initializer())
  checkpoint.save(checkpoint_path, session)
  session.run(train_op)
  print('Finished Model with SGD')

# Stage 2: rebuild the model with Adam and restore the SGD checkpoint.
with tf.Graph().as_default():
  model = MyModel()
  model.optimizer_to_adam()
  latest = tf.train.latest_checkpoint(checkpoint_path)
  train_op = model.optimizer.minimize(model.var * model.var)
  checkpoint = tfe.Checkpoint(model=model)
  print(checkpointable_utils._serialize_object_graph(checkpoint))
  session = tf.Session()
  # Expected: restore what the checkpoint has and initialize the rest.
  checkpoint.restore(latest).initialize_or_restore(session)
  session.run(train_op)  # This is the step that raises the error.
  print('Finished Model with Adam')

I would expect the second graph (the one that prints "Finished Model with Adam") to work, but it fails with the error "Attempting to use uninitialized value beta2_power". I expected .initialize_or_restore to initialize the Adam variables, but that does not happen.

The Adam variables do appear in the output of print(checkpointable_utils._serialize_object_graph(checkpoint)), though.

If you think that’s a bug I can open up a new issue to track it. I hope I am not using the API incorrectly.