tensorflow: [2.2rc3] Distributed training with Keras and ThreadPoolDataset runs out of memory

System information

  • Have I written custom code (as opposed to using a stock example script provided in TensorFlow): Yes
  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Ubuntu 18.04
  • TensorFlow installed from (source or binary): binary
  • TensorFlow version (use command below): 2.2.0rc3 and tf-nightly
  • Python version: 3.7
  • CUDA/cuDNN version: 10.1 / 7.6.5.32
  • GPU model and memory: 4 x NVIDIA V100 on GCP

Describe the current behavior

When running the code below with cached training and validation datasets in a multi-GPU environment (I am using a GCP VM with 312 GB of memory and 4 NVIDIA V100s), memory increases during each validation run until the VM runs out of memory. This behaviour can be observed on 2.2.0-rc3 and on the latest nightly.

It looks like the validation dataset is not properly cached: I can still see network access during validation, and the memory usage drops below the theoretical cached memory requirement after validation has finished, then increases linearly during the next validation round to a point larger than the memory usage in the previous epoch. In the example below I intentionally use a very large validation set to make this memory increase obvious and make training crash within the first 5 epochs. The behaviour can also be observed with other datasets, but the memory increase will be less noticeable on smaller datasets.

In which cases is the memory usage still stable? To narrow down the possible causes, I found two cases where this issue does not occur:

  1. When running on a single GPU, memory usage is stable.

  2. TensorFlow Datasets uses a _PrivateThreadPoolDataset by setting experimental_threading.private_threadpool_size=16 as a default option. When this option is disabled, memory usage is stable again. Unfortunately this is not a valid workaround in userland, since the dataset option cannot be overwritten with experimental_threading.private_threadpool_size=None as it expects an integer (see the sketch after this list).
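
For reference, this is roughly the override I attempted from user code (a minimal sketch; dataset stands for the pipeline returned by tfds.load in the repro below):

import tensorflow as tf

# Minimal sketch of the attempted workaround. tensorflow_datasets already
# sets private_threadpool_size=16 as a default option on the returned dataset.
options = tf.data.Options()
# This does not work: the option expects an integer, so it cannot be reset
# to "unset" from user code and the private thread pool stays enabled.
options.experimental_threading.private_threadpool_size = None
dataset = dataset.with_options(options)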

@yhliang2018 @tomerk @byronyi This seems to be a complicated interaction between tf.data, tf.keras and tf.distribute; do you have an idea what could cause this behaviour? Please let me know what additional information I could provide. I've run into similar issues with experimental_threading.private_threadpool_size on TF 2.0.0 in the past, though I never investigated the root cause in detail, so this might not be an entirely new regression.

Describe the expected behavior

Memory usage should be stable after the first epoch.

Standalone code to reproduce the issue

import tensorflow as tf
import tensorflow_datasets as tfds


batch_size = 1024

dataset = tfds.load(
    "imagenet2012:5.0.0",
    decoders={"image": tfds.decode.SkipDecoding()},
    split="train",
    data_dir="gs://my-cloud-bucket",
)

val_dataset = tfds.load(
    "imagenet2012:5.0.0",
    decoders={"image": tfds.decode.SkipDecoding()},
    split="validation",
    data_dir="gs://my-cloud-bucket",
)


def _decode_and_center_crop(image_bytes):
    """Crops to center of image with padding then scales image_size."""
    shape = tf.image.extract_jpeg_shape(image_bytes)
    image_height = shape[0]
    image_width = shape[1]
    image_size = 224

    padded_center_crop_size = tf.cast(
        (
            (image_size / (image_size + 32))
            * tf.cast(tf.minimum(image_height, image_width), tf.float32)
        ),
        tf.int32,
    )

    offset_height = ((image_height - padded_center_crop_size) + 1) // 2
    offset_width = ((image_width - padded_center_crop_size) + 1) // 2
    crop_window = tf.stack(
        [offset_height, offset_width, padded_center_crop_size, padded_center_crop_size]
    )
    image = tf.image.decode_and_crop_jpeg(image_bytes, crop_window, channels=3)
    return tf.image.resize(image, [image_size, image_size], method="bicubic")


def preprocessing(data):
    return tf.cast(_decode_and_center_crop(data["image"]), tf.float32), data["label"]

dataset = (
    dataset.cache()
    .map(preprocessing, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .batch(batch_size)
    .prefetch(1)
)

val_dataset = (
    val_dataset.cache()
    .map(preprocessing, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .batch(batch_size)
    .prefetch(1)
)

with tf.distribute.MirroredStrategy().scope():
    model = tf.keras.models.Sequential(
        [
            tf.keras.layers.GlobalMaxPool2D(input_shape=(224, 224, 3)),
            tf.keras.layers.Dense(1000, activation="softmax"),
        ]
    )

    model.compile(
        optimizer="adam",
        loss="sparse_categorical_crossentropy",
        metrics=["accuracy", "sparse_top_k_categorical_accuracy"],
    )

# Note: the large ImageNet train split is deliberately passed as
# validation_data to make the memory growth during validation obvious.
model.fit(val_dataset, epochs=5, validation_data=dataset)

Other info / logs

To monitor memory usage over time, tools like ytop can be used.
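
Alternatively, a small Keras callback can log the resident memory of the process at the end of every epoch (a minimal sketch assuming psutil is installed; MemoryLogger is a hypothetical helper, not part of the repro above):

import psutil
import tensorflow as tf

# Hypothetical helper: prints the resident set size after each epoch so the
# growth during validation is visible in the training logs.
class MemoryLogger(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        rss_gb = psutil.Process().memory_info().rss / 1e9
        print(f"epoch {epoch}: resident memory = {rss_gb:.1f} GB")

# Usage: model.fit(..., callbacks=[MemoryLogger()])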

About this issue

  • State: closed
  • Created 4 years ago
  • Reactions: 2
  • Comments: 31 (24 by maintainers)

Most upvoted comments

@lgeiger after training for 5 hrs, TF Nightly version 2.2.0.dev20200428 shows flat memory utilization.

That said, I am hitting another error on this version (the TensorBoard callback can't be used). I will open a separate issue.

@lgeiger and @alimhanif thanks for verifying that the issue is fixed with tf-nightly

Can this fix from 2.2.0.dev20200428 be cherry-picked into 2.2.0rc3?

@alimhanif This is in progress and will probably be included in the stable release of 2.2.0. See #38996

I just retested this with the latest nightly and it seems that this has been fixed by 7ebbab819e736319ec35b48e31f4d62fbad6626b as well. I’m still not sure what the interaction with the thread pool dataset is though.

Thanks @lgeiger, I'll ask some folks in XLA to take a look at that one.

Yes, I think those arguments are obsolete and should be removed. Thank you for your PR.