tensorflow: Dataset prefetch not working as expected, not storing data in memory
System information
- Have I written custom code (as opposed to using a stock example script provided in TensorFlow): yes
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Linux Ubuntu 18.04 LTS
- TensorFlow installed from (source or binary): conda-forge
- TensorFlow version (use command below): unknown 1.14.0
- Python version: Python 3.7.3
- CUDA/cuDNN version: NVIDIA-SMI 418.67, Driver Version: 418.67, CUDA Version: 10.1
- GPU model and memory: Quadro RTX 6000, 24190MiB
- Exact command to reproduce:
Describe the current behavior
I am training a small LSTM model. Until recently I could use Dataset.from_tensor_slices() on numpy arrays directly, because all training data fit into memory. Unfortunately, after adding some new data I ran into the 2 GB graph size limitation and was forced to switch to TFRecord files and TFRecordDataset. The actual training data still fits into RAM, however, and I want to make sure it is prefetched even when using the TFRecordDataset. I therefore tried Dataset.prefetch(), assuming a buffer would be created and kept (constantly) filled with data. However, it does not work: in fact, there seems to be little to no difference between a version with and without a final .prefetch(x) in the data pipeline, as shown in the animated GIF below:

[animated GIF: training progress with and without .prefetch() in the pipeline]
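(For context, a minimal sketch of how such records could be written; this is not the export code actually used. The file name and the X/y arrays are placeholders, and the field names follow the feature_description shown further below.)

import numpy as np
import tensorflow as tf

def serialize_example(features, label):
    # features: 1-D float array of length 132, label: a single float
    example = tf.train.Example(features=tf.train.Features(feature={
        "features": tf.train.Feature(float_list=tf.train.FloatList(value=features)),
        "label": tf.train.Feature(float_list=tf.train.FloatList(value=[label])),
    }))
    return example.SerializeToString()

with tf.io.TFRecordWriter("train_000.tfrecord") as writer:  # placeholder file name
    for features, label in zip(X, y):                        # X, y: the in-memory numpy arrays
        writer.write(serialize_example(features, label))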
The actual dataset is filtered in the pipeline, and training stalls whenever a run of values that gets filtered out occurs in the data. Only a few values in each data/tfrecord file (of which there are many) are relevant. To illustrate, the data layout is similar to this:
file 1: [---------#####----------------####------]
file 2: [--------######----------------###-------]
file 3: [----------####----------------####------]
file 4: ...
...
where - denotes irrelevant and # denotes relevant data points in a time series. When holding all values in memory (as was previously the case), the filter is rather fast and irrelevant values are skipped without noticeable delay.
The data pipeline is set up like this:
import tensorflow as tf

feature_description = {
    "features": tf.FixedLenFeature([132], tf.float32),
    "label": tf.FixedLenFeature([1], tf.float32),
}

def _parse_function(example_proto):
    return tf.parse_single_example(example_proto, feature_description)

ds = tf.data.TFRecordDataset([f.as_posix() for f in fs_train])
ds = ds.map(_parse_function)
ds = ds.flat_map(lambda v: tf.data.Dataset.from_tensors((v["features"][2:], v["label"])))
# filter data, only allow Ls[0] and Ls[1]
ds = ds.filter(
    lambda _, y: tf.reshape(tf.logical_or(
        tf.equal(y, Ls[0]),
        tf.equal(y, Ls[1])
    ), [])
)
# relabel and re-map labels to 0 and 1
ds = ds.flat_map(lambda x, y: tf.data.Dataset.from_tensors((x, tf_relabel(y) - base_label.value)))
# create sliding window for LSTM
ds = ds.window(size=window_size, shift=shift, stride=stride, drop_remainder=True)
ds = ds.flat_map(lambda x, y: tf.data.Dataset.zip((x.batch(window_size), y.batch(window_size))))
# batch and prefetch
ds = ds.batch(batch_size, drop_remainder=True)
ds = ds.prefetch(1000000000000000)  # tried many values, nothing works
Describe the expected behavior
I expect to find some value for Dataset.prefetch() that reads all (or enough) of the data into memory to allow for fast training without stalling.
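(As the comments further down point out, prefetch() only buffers a bounded number of elements ahead of the consumer, while Dataset.cache() is the operation that actually keeps elements in memory, or on disk when given a filename. A minimal sketch of where it could slot into the pipeline above; the placement after batching is an assumption:)

ds = ds.batch(batch_size, drop_remainder=True)
ds = ds.cache()    # keeps already-produced elements in RAM from the second epoch on; pass a path to cache on disk instead
ds = ds.prefetch(1)  # overlap producing the next batch with training on the current one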
Code to reproduce the issue
See the data pipeline above. I cannot provide the data as it is proprietary.
About this issue
- Original URL
- State: closed
- Created 5 years ago
- Reactions: 2
- Comments: 20 (9 by maintainers)
@mimxrt if you are purely aiming for throughput, you might be better off doing offline filtering. Also, there are a couple of pieces of low-hanging fruit, viz. (a sketch of how these could be wired into the pipeline above follows after the next comment):
- the num_parallel_calls arg in map for parallel processing of dataset elements
- tf.data.experimental.parallel_interleave or tf.data.Dataset.interleave when processing tfrecords (note that parallel_interleave is deprecated)
- tf.data.experimental.AUTOTUNE for buffer_size in prefetch and num_parallel_calls in map
- tf.data.Options

Thank you for your kind offer, that's really nice. However, due to project deadlines, I had to switch to pre-processing the data in Dask and storing the already pre-processed data in tfrecord files. Doing this, the TensorFlow Dataset pipeline is reduced to merely reading the files and the training is accelerated. Of course, I am still interested in the more elegant way of pre-processing the data in parallel on the fly, but I would feel bad if you invested a lot of your time into this while there is no emergency.
Additionally, the data is proprietary, although, of course, I could provide you with fake data of the same shape and size. So, if you are interested in this as well, let me know and I will create a minimal working example including fake data. In either case, thank you very much for your kind support!
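For reference, a rough sketch of how those parallelization suggestions could be applied to the pipeline from the issue (this is not the pipeline actually used; cycle_length=4 is an arbitrary placeholder, and _parse_function/fs_train are the names defined above):

files = [f.as_posix() for f in fs_train]
ds = tf.data.Dataset.from_tensor_slices(files)
# read several tfrecord files concurrently instead of one after another
ds = ds.interleave(tf.data.TFRecordDataset,
                   cycle_length=4,  # number of files read in parallel (placeholder value)
                   num_parallel_calls=tf.data.experimental.AUTOTUNE)
# parse records on multiple cores
ds = ds.map(_parse_function, num_parallel_calls=tf.data.experimental.AUTOTUNE)
# ... filtering, relabelling, windowing and batching as in the original pipeline ...
ds = ds.prefetch(tf.data.experimental.AUTOTUNE)  # let tf.data pick the buffer size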
Thank you for your comment. I understand, and I am currently (hopefully) profiling the model fitting to confirm your hypothesis. So far I have been struggling to get TensorBoard to visualize the profile. Anyway, I wanted to let you know that I am still looking into this.
As a side note:
Dataset.cache() seems to work when using a hard disk file instead of RAM. From the second epoch onward, training is now as fast as I would expect, which further supports your hypothesis.
EDIT: I also want to stress that something is still amiss with the data pipeline even if prefetch() turns out to work as expected. The performance of epoch 0 (i.e. with an empty cache) is extremely slow compared to reading and pre-processing the data with Dask/Pandas and feeding plain numpy arrays to the Dataset API. It could be that the pipeline does not read/process data in parallel, but I am unsure how to enable this. num_parallel_calls and tf.data.Dataset.interleave have had no effect for me so far. I can confirm that the CPU usage is low compared to the Dask pipeline when using the TF Dataset API.

prefetch works as expected. It decouples the producer from the consumer using an internal buffer. What I suspect is going on is that your input pipeline is running slower than your training step, which means you get little benefit from prefetching. You could confirm this hypothesis by separately benchmarking a) the latency of your input pipeline and b) the latency of your model with synthetic data. If a) is much higher than b), you will get little benefit from prefetching and should instead focus on optimizing the performance of your input pipeline through parallelization. My recommendation would be to use the TensorBoard Keras profiler to understand what is going on in your input pipeline. If you share a link to your trace, I would be happy to provide insights.
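A rough sketch of such a benchmark, assuming eager execution and a Keras model, and reusing ds, batch_size and window_size from the pipeline above; the helper names, batch counts, synthetic shapes and the name model are placeholders, not code from the thread:

import time
import numpy as np
import tensorflow as tf
tf.enable_eager_execution()  # TF 1.x only, must run at program start; eager is the default in TF 2.x

def input_pipeline_latency(ds, num_batches=100):
    # a) time spent producing batches, with no training attached
    start = time.time()
    for _ in ds.take(num_batches):
        pass
    return (time.time() - start) / num_batches

def model_step_latency(model, num_batches=100):
    # b) time per training step on synthetic, in-memory data
    # shapes assume 130 features (132 minus the two dropped columns) and one 0/1 label per time step
    x = np.random.rand(batch_size, window_size, 130).astype(np.float32)
    y = np.random.randint(0, 2, size=(batch_size, window_size, 1)).astype(np.float32)
    start = time.time()
    for _ in range(num_batches):
        model.train_on_batch(x, y)
    return (time.time() - start) / num_batches

print("input pipeline:", input_pipeline_latency(ds), "s/batch")
print("model step:    ", model_step_latency(model), "s/batch")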