NVTabular: [BUG] NVTabular data loader for TensorFlow validation is slow

Describe the bug

Using the NVTabular data loader for TensorFlow for validation with the Criteo dataset is slow:

    validation_callback = KerasSequenceValidater(valid_dataset_tf)
    history = model.fit(train_dataset_tf, 
                        epochs=EPOCHS, 
                        steps_per_epoch=20, callbacks=[validation_callback])

Training: 2 min for 2288 steps
Validation: estimated 55 min for 3003 steps
Same batch size, dataset, etc. …

The validation dataset is 1.3x bigger, but iterating through it takes 27x longer than training.
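As a quick check of those ratios, the arithmetic can be written out as below. The variable names are illustrative only; the numbers are the ones reported above. Normalizing by step count suggests each validation step is roughly 21x slower than a training step.

```python
# Figures reported in this issue.
train_steps, train_minutes = 2288, 2
valid_steps, valid_minutes = 3003, 55

# How much larger the validation set is, measured in steps.
size_ratio = valid_steps / train_steps        # about 1.31

# How much longer the validation pass takes in wall-clock time.
time_ratio = valid_minutes / train_minutes    # 27.5

# Slowdown per step once the size difference is factored out.
per_step_slowdown = time_ratio / size_ratio   # about 21

print(f"size ratio: {size_ratio:.2f}x")
print(f"time ratio: {time_ratio:.1f}x")
print(f"per-step slowdown: {per_step_slowdown:.1f}x")
```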

Steps/Code to reproduce bug

An example is provided here: https://github.com/bschifferer/NVTabular/blob/criteo_tf_slow/examples/criteo_tensorflow_slow.ipynb

Iterating over the training dataset takes ~1 s per 10 batches on average, but iterating over the validation dataset takes 4-6 s per 10 batches.
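Per-10-batch timings like these can be collected with a small helper. This is a minimal sketch, not the code from the linked notebook: `time_batches` and `step_fn` are hypothetical names, and `step_fn` stands in for whatever runs the model's forward pass on a batch.

```python
import time
from typing import Callable, Iterable, List, Optional


def time_batches(loader: Iterable,
                 step_fn: Optional[Callable] = None,
                 report_every: int = 10) -> List[float]:
    """Return wall-clock seconds spent on each group of `report_every` batches."""
    timings = []
    start = time.perf_counter()
    for i, batch in enumerate(loader, start=1):
        if step_fn is not None:
            # If the framework executes asynchronously, step_fn should
            # block on its result so the measured time is meaningful.
            step_fn(batch)
        if i % report_every == 0:
            now = time.perf_counter()
            timings.append(now - start)
            start = now
    return timings


if __name__ == "__main__":
    # Stand-in loader; a real run would pass the data loader object instead.
    fake_loader = range(30)
    for n, seconds in enumerate(time_batches(fake_loader), start=1):
        print(f"batches {10 * (n - 1) + 1}-{10 * n}: {seconds:.6f}s")
```

Passing `step_fn=None` measures the data loader alone, which helps separate loader cost from model cost.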

Expected behavior

The validation loop should be about as fast as the training loop.

Additional context

There are multiple hypotheses and tests:

  • This behavior can be observed by just iterating over the dataset and executing the forward pass of the model
  • If we switch the training/validation data loaders, then the validation data loader is fast and the training data loader is slow. That is, iterating over whichever data loader comes second is slow. The hypothesis is that the GPU memory is not released by the 1st data loader and blocks the pipeline
  • If we remove the forward pass of the model from the loop, then both iterations are fast. It probably has something to do with moving the data to the TensorFlow model
  • I tried calling tf.keras.backend.clear_session() between the iterations, but it did not help
  • I tried running the following between the iterations, but it resulted in an error:

        from numba import cuda
        cuda.select_device(0)
        cuda.close()

  • I tried using separate subprocesses, but it did not improve performance
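The order-swap and forward-pass tests above can be laid out as a small experiment matrix. This is a sketch under assumptions, not the notebook's code: `run_in_order` and `compare` are hypothetical helpers, the loader factories are assumed to build a fresh data loader on each call, and `step_fn` stands in for the model's forward pass.

```python
import time
from typing import Callable, Dict, Iterable, List, Optional, Tuple

LoaderFactory = Callable[[], Iterable]


def run_in_order(loaders: List[Tuple[str, LoaderFactory]],
                 step_fn: Optional[Callable] = None) -> Dict[str, float]:
    """Iterate the loaders one after another; return total seconds for each."""
    totals = {}
    for name, make_loader in loaders:
        start = time.perf_counter()
        for batch in make_loader():
            if step_fn is not None:
                step_fn(batch)
        totals[name] = time.perf_counter() - start
    return totals


def compare(train: LoaderFactory, valid: LoaderFactory,
            step_fn: Callable) -> Dict[Tuple[str, bool], Dict[str, float]]:
    """Time both orderings, with and without the forward pass."""
    results = {}
    for with_forward in (True, False):
        fn = step_fn if with_forward else None
        results[("train_first", with_forward)] = run_in_order(
            [("train", train), ("valid", valid)], fn)
        results[("valid_first", with_forward)] = run_in_order(
            [("valid", valid), ("train", train)], fn)
    return results


if __name__ == "__main__":
    # Stand-in loaders and a no-op step; real runs would use the data loaders.
    out = compare(lambda: range(1000), lambda: range(1300), lambda b: None)
    for key, totals in out.items():
        print(key, totals)
```

If the slowdown follows the second position rather than a particular dataset, that points at state left behind by the first loader. Because that state may carry over between configurations, each configuration is best run in a fresh process for a clean comparison.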

About this issue

  • State: closed
  • Created 4 years ago
  • Comments: 15 (15 by maintainers)

Most upvoted comments

@jperez999 we think this is happening in the TF dataloader only. @bschifferer will confirm by testing PyTorch. Can you take a look, please?