tensorflow: `recompute_grad` does not save memory and is incompatible with graph mode

System information

  • Have I written custom code (as opposed to using a stock example script provided in TensorFlow): No.
  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Linux Ubuntu 16.04 and Windows 10.
  • TensorFlow installed from (source or binary): from binary (pip install)
  • TensorFlow version (use command below): 2.1.0
  • Python version: 3.7
  • CUDA/cuDNN version: CUDA 10.2 + cuDNN 7.6.5 (Windows); CUDA 10.1 + cuDNN 7.6.5 + TensorRT 6 (Ubuntu)
  • GPU model and memory: GeForce GTX 1060 with Max-Q Design, 6GB (Windows) and GeForce GTX 1080 Ti, 12GB (Ubuntu)

Describe the current behavior

Wrapping Keras layers with tf.recompute_grad has no effect. I built a DenseNet model and wrapped each “-bn-relu-conv1x1-bn-relu-conv” block with the function, but I have not observed any reduction in GPU memory usage on either Windows or Ubuntu. When eager mode is disabled, training fails with “ValueError: Variable <tf.Variable 'batch_normalization/gamma:0' shape=(32,) dtype=float32> has None for gradient.”, which indicates that tf.recompute_grad blocks gradient backpropagation in graph mode.

Describe the expected behavior

The function seems to originate from OpenAI’s gradient checkpointing (https://github.com/cybertronai/gradient-checkpointing) and is expected to reduce GPU memory usage during training. A TensorFlow implementation of memory-efficient DenseNets (https://github.com/joeyearsley/efficient_densenet_tensorflow) also uses this function for gradient checkpointing, although that project calls tf.contrib.layers.recompute_grad in TF1 graph mode, which is not exactly the same environment as ours.
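To make the expected behavior concrete, here is a toy, framework-free sketch of segment-based gradient checkpointing. It is not TensorFlow's implementation; the names LAYERS, grad_full and grad_checkpointed are hypothetical. A 12-layer chain of scalar functions is differentiated by hand twice: once keeping every layer input until the backward pass, and once keeping only the input of every `segment`-th layer and re-running the forward pass of each segment when its gradients are needed. Both give the same gradient, but the checkpointed version holds fewer values at its peak, at the cost of extra forward computation.

```python
import math

# Hypothetical toy "network": a chain of scalar layers, each given as
# (forward function, derivative of that function).
LAYERS = [
    (math.sin, math.cos),
    (math.tanh, lambda x: 1.0 - math.tanh(x) ** 2),
    (lambda x: x * x, lambda x: 2.0 * x),
] * 4  # 12 layers in total


def grad_full(x):
    """Plain backprop: keep every layer input alive until the backward pass."""
    inputs = []
    for f, _ in LAYERS:
        inputs.append(x)
        x = f(x)
    g = 1.0
    for (_, df), xi in zip(reversed(LAYERS), reversed(inputs)):
        g *= df(xi)  # chain rule, from the last layer back to the first
    return x, g, len(inputs)


def grad_checkpointed(x, segment):
    """Keep only the input of every `segment`-th layer; recompute the rest."""
    checkpoints = {}
    for i, (f, _) in enumerate(LAYERS):
        if i % segment == 0:
            checkpoints[i] = x  # all other inputs are discarded after use
        x = f(x)
    g = 1.0
    peak = len(checkpoints)
    for start in sorted(checkpoints, reverse=True):
        end = min(start + segment, len(LAYERS))
        # Re-run the forward pass of this segment only, starting from its
        # checkpoint, to recover the inputs that were thrown away.
        xs = [checkpoints[start]]
        for f, _ in LAYERS[start:end - 1]:
            xs.append(f(xs[-1]))
        # Values held right now: all checkpoints plus the recomputed inputs
        # (xs[0] is the checkpoint itself, so it is not counted twice).
        peak = max(peak, len(checkpoints) + len(xs) - 1)
        for (_, df), xi in zip(reversed(LAYERS[start:end]), reversed(xs)):
            g *= df(xi)
    return x, g, peak


y_full, g_full, mem_full = grad_full(0.5)
y_ckpt, g_ckpt, mem_ckpt = grad_checkpointed(0.5, segment=4)
print(mem_full, mem_ckpt)  # → 12 6
```

In this toy model the saving comes purely from not keeping intermediate inputs alive; the report above is that wrapping the Keras blocks with tf.recompute_grad did not produce the analogous reduction on the GPU.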

Please fix the incompatibility so that the function also works in graph mode. If the function is designed to perform gradient checkpointing, please verify that it actually saves memory. If it is not intended for implementing efficient DenseNets, please provide a correct and effective implementation.

Standalone code to reproduce the issue

import os

import tensorflow as tf
import tensorflow_datasets as tfds
from absl import app, flags
from absl.flags import FLAGS
from tensorflow import keras

flags.DEFINE_list("gpu",
                  default=None,
                  help="index of GPU")
flags.DEFINE_bool("recompute_grad",
                  default=False,
                  help="whether to recompute gradients to save GPU RAM")
flags.DEFINE_integer("batch_size",
                     default=1024,
                     help="batch size")
flags.DEFINE_bool("graph",
                  default=False,
                  help="use graph mode instead of eager mode")


def dense_lenet(inputs):
    net = keras.layers.Conv2D(32, 5, strides=2, use_bias=False, padding="SAME")(inputs)

    for _ in range(5):
        def _block(x):
            x = keras.layers.BatchNormalization()(x)
            x = keras.layers.ReLU()(x)
            x = keras.layers.Conv2D(16, 1, use_bias=False, padding="SAME")(x)
            x = keras.layers.BatchNormalization()(x)
            x = keras.layers.ReLU()(x)
            x = keras.layers.Conv2D(4, 3, use_bias=False, padding="SAME")(x)
            return x
        if FLAGS.recompute_grad:
            _block = tf.recompute_grad(_block)
        net = keras.layers.concatenate([net, _block(net)])

    net = keras.layers.BatchNormalization()(net)
    net = keras.layers.ReLU()(net)
    net = keras.layers.Conv2D(64, 1, use_bias=False, padding="SAME")(net)
    net = keras.layers.AveragePooling2D()(net)

    for _ in range(10):
        def _block(x):
            x = keras.layers.BatchNormalization()(x)
            x = keras.layers.ReLU()(x)
            x = keras.layers.Conv2D(32, 1, use_bias=False, padding="SAME")(x)
            x = keras.layers.BatchNormalization()(x)
            x = keras.layers.ReLU()(x)
            x = keras.layers.Conv2D(8, 3, use_bias=False, padding="SAME")(x)
            return x
        if FLAGS.recompute_grad:
            _block = tf.recompute_grad(_block)
        net = keras.layers.concatenate([net, _block(net)])

    net = keras.layers.BatchNormalization()(net)
    net = keras.layers.ReLU()(net)
    net = keras.layers.Conv2D(128, 1, use_bias=False, padding="SAME")(net)
    net = keras.layers.AveragePooling2D()(net)

    for _ in range(10):
        def _block(x):
            x = keras.layers.BatchNormalization()(x)
            x = keras.layers.ReLU()(x)
            x = keras.layers.Conv2D(32, 1, use_bias=False, padding="SAME")(x)
            x = keras.layers.BatchNormalization()(x)
            x = keras.layers.ReLU()(x)
            x = keras.layers.Conv2D(8, 3, use_bias=False, padding="SAME")(x)
            return x
        if FLAGS.recompute_grad:
            _block = tf.recompute_grad(_block)
        net = keras.layers.concatenate([net, _block(net)])

    net = keras.layers.BatchNormalization()(net)
    net = keras.layers.ReLU()(net)
    net = keras.layers.GlobalAveragePooling2D()(net)

    net = keras.layers.Dense(10)(net)
    net = keras.layers.Softmax()(net)

    return net


def main(_):
    if FLAGS.gpu:
        os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, FLAGS.gpu))
    if FLAGS.graph:
        tf.compat.v1.disable_eager_execution()
        tf.compat.v1.keras.backend.set_session(
            session=tf.compat.v1.Session(
                config=tf.compat.v1.ConfigProto(
                    gpu_options=tf.compat.v1.GPUOptions(
                        allow_growth=True
                    )
                )
            )
        )
    else:
        for gpu in tf.config.experimental.list_physical_devices('GPU'):
            tf.config.experimental.set_memory_growth(gpu, True)

    tfds.core.constants.DATA_DIR = "data"
    dataset_builder = tfds.image.FashionMNIST(version="3.*.*")
    dataset_builder.download_and_prepare()
    dataset = dataset_builder.as_dataset(
        split="train",
        shuffle_files=True,
        as_supervised=True,
    ).repeat().batch(FLAGS.batch_size)

    inputs = keras.layers.Input((28, 28, 1), batch_size=FLAGS.batch_size)
    model = keras.Model(inputs, dense_lenet(inputs))

    model.compile(
        optimizer='adam',
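        # The model already ends in a Softmax layer, so the loss receives
        # probabilities rather than logits.
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),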
        metrics=['accuracy']
    )
    model.summary()

    model.fit(
        x=dataset,
        epochs=3,
        steps_per_epoch=60000//FLAGS.batch_size,
    )


if __name__ == "__main__":
    app.run(main)

About this issue

  • State: open
  • Created 4 years ago
  • Reactions: 1
  • Comments: 15 (2 by maintainers)

Most upvoted comments

@BinyanHu if you’re interested, I’ve written a simple gradient checkpointing decorator here

If you’re looking to do gradient checkpointing in graph mode, I suggest the tf-slim implementation, which I’ve extracted and successfully tested on tf-nightly in graph mode on TPU: https://github.com/google-research/tf-slim/blob/a62dc893de5e46e6f2e9ec24a74b2abce026307a/tf_slim/layers/rev_block_lib.py

@paulter I have a version working with Keras, but for sequential models only. I have created a pull request in the TF Addons GitHub repo: https://github.com/tensorflow/addons/pull/1600. You can find an example notebook here: https://github.com/pidajay/addons/blob/grad_checkpointing_eager/docs/tutorials/training_gradient_checkpointing.ipynb
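For a purely sequential stack, checkpointing largely comes down to choosing how many layers go into each recomputed segment. The sketch below uses a rough, hypothetical cost model (the helper names peak_activations and best_segment are illustrative only): each layer input counts as one unit, one checkpoint per segment stays alive for the whole backward pass, and one segment's inputs are recomputed at a time. Layer sizes, parameters and framework overheads are ignored, so it only indicates the trend; the minimum lands near the square-root rule commonly associated with this technique.

```python
import math


def peak_activations(n_layers, segment):
    """Rough peak count of stored layer inputs with segment checkpointing.

    Assumed cost model: one checkpoint per segment is kept for the whole
    backward pass, plus the recomputed inputs of one segment at a time
    (the first of those is the checkpoint itself, hence the -1).
    """
    n_checkpoints = math.ceil(n_layers / segment)
    return n_checkpoints + min(segment, n_layers) - 1


def best_segment(n_layers):
    """Segment size (>= 1) minimising peak_activations; smallest on ties."""
    return min(range(1, n_layers + 1),
               key=lambda k: peak_activations(n_layers, k))


k = best_segment(100)
print(k, peak_activations(100, k))  # → 10 19
```

With segment=1 (or segment=n_layers) the model gives n_layers stored inputs, the same as no checkpointing, which is why an intermediate segment size is needed to see any saving.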

@BinyanHu I’ve got a minimal example showing that it doesn’t reduce memory usage here: https://github.com/tensorflow/tensorflow/issues/30418#issuecomment-589820336