tensorflow: Memory leak when using py_function inside tf.data.Dataset

System information

  • Have I written custom code (as opposed to using a stock example script provided in TensorFlow):
  • OS Platform and Distribution: Linux Ubuntu 16.04
  • Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if the issue happens on mobile device:
  • TensorFlow installed from (source or binary):
  • TensorFlow version (use command below): 2.0
  • Python version: 3.6.8
  • Bazel version (if compiling from source):
  • GCC/Compiler version (if compiling from source):
  • CUDA/cuDNN version:
  • GPU model and memory:

Describe the current behavior

Process memory (RSS) keeps growing as the loop runs. [Screenshot, 2019-12-13 6:06:20 PM]

Describe the expected behavior

The tf.data.Dataset created in each step should be freed at the end of that step, so memory usage stays flat across steps.

Code to reproduce the issue

import tensorflow as tf
import os
import numpy as np
import psutil

def _generator():
    for i in range(100):
        yield "1,2,3,4,5,6,7,8"

def _py_parse_data(record):
    record = record.numpy()
    record = bytes.decode(record)
    rl = record.split(",")
    rl = [str(int(r) + 1) for r in rl]
    return [",".join(rl)]

def parse_data(record, shape=10):
    sparse_data = tf.strings.split([record], sep=",")
    sparse_data = tf.strings.to_number(sparse_data[0], tf.int64)
    ids_num = tf.cast(tf.size(sparse_data), tf.int64)
    indices = tf.range(0, ids_num, dtype=tf.int64)
    indices = tf.reshape(indices, shape=(-1, 1))
    sparse_data = tf.sparse.SparseTensor(
                indices, sparse_data, dense_shape=(shape,)
    )
    return sparse_data

process = psutil.Process(os.getpid())

# A new dataset (and therefore a new py_function) is built on every step
for step in range(10000):
    t = tf.data.Dataset.from_generator(_generator, output_types=tf.string)
    t = t.map(lambda record: tf.py_function(_py_parse_data, [record], [tf.string]))
    t = t.map(parse_data)
    for _ in t:  # consume every element; the values themselves are not used
        pass
    if step % 10 == 0:
        print("Memory : ", process.memory_info().rss)

About this issue

  • State: closed
  • Created 5 years ago
  • Comments: 15 (9 by maintainers)


Most upvoted comments

Closing as stale. Please reopen if you’d like to work on this further.

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you.

@QiJune, is this still an issue? When running the code with TF v2.2, I did not observe much difference in memory usage between iterations. Please find the gist of it here. Thanks!

@kkimdev @jsimsa Hello, we are having the same problem with TF 1.14. We use tf.py_function to load a wave file:

results = tf.py_function(
    self.safe_load,
    [audio_descriptor, offset, duration, sample_rate, dtype],
    (tf.float32, tf.bool))
# py_function returns one tensor per entry in Tout
waveform, error = results

We then put this into a tf.data.Dataset:

dataset = dataset.map(
        lambda sample: dict(
            sample,
            **audio_adapter.load_tf_waveform(
                sample['audio_id'],
                session=session,
                sample_rate=sample_rate,
                offset=sample['start'],
                duration=sample['end'] - sample['start'])),
        num_parallel_calls=2)

and we get a leak in which the memory leaked per iteration is roughly the size of the wave file being loaded:

after prediction traced memory: 28670 KiB  peak: 28673 KiB  overhead: 29677 KiB
after load traced memory: 28801 KiB  peak: 28808 KiB  overhead: 29755 KiB
after prediction traced memory: 53988 KiB  peak: 55396 KiB  overhead: 54529 KiB
after load traced memory: 54100 KiB  peak: 55396 KiB  overhead: 54604 KiB
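The script behind these lines isn't shown. Output of this shape can be produced with Python's built-in tracemalloc module. The sketch below assumes the three columns are current traced memory, peak traced memory, and tracemalloc's own bookkeeping memory. The `report` helper is hypothetical and may not match what the poster used.

```python
import tracemalloc

def report(label):
    # (current, peak) bytes of memory allocated through Python's allocators since start()
    current, peak = tracemalloc.get_traced_memory()
    # bytes used by tracemalloc itself to store the traces
    overhead = tracemalloc.get_tracemalloc_memory()
    line = (f"{label} traced memory: {current // 1024} KiB  "
            f"peak: {peak // 1024} KiB  overhead: {overhead // 1024} KiB")
    print(line)
    return line

tracemalloc.start()
buffers = [bytes(1024) for _ in range(1000)]  # allocate roughly 1 MiB to see it traced
report("after allocation")
```

tracemalloc only sees memory allocated through Python's allocators. Native allocations made directly by C/C++ code are invisible to it, so an RSS measurement like the one in the original repro can show growth that tracemalloc misses.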

OK, got it. But then you don't instantiate the dataset in a loop?

@loretoparisi do you also create the dataset in a for loop or do you instantiate it only once ?

I am asking because I suspect a memory leak as well, but I am only creating one dataset object and then training on it using fit. On my side, I use tf.py_function to load HDF5 files because of this error in tfio.

I would also be interested in the script you used to get the last lines of your post.

List of leaking objects per 100 iterations:

=======================================================================
Type                     Old_ids  Current_ids      New_ids Count_Deltas
=======================================================================
dict                       49562        50462         +904         +900
cell                       19873        20673         +800         +800
tuple                      39958        40658         +700         +700
function                   71183        71783         +600         +600
list                       27319        27819         +502         +500
KeyedRef                    3800         4200         +400         +400
EagerTensor                 1900         2100         +200         +200
method                      1413         1513         +100         +100
_GeneratorState              950         1050         +100         +100
TensorShape                  953         1053         +100         +100
Tape                         950         1050         +100         +100
GradientTape                 950         1050         +100         +100
EagerFunc                    950         1050         +100         +100
StringIO                       3            3           +1           +0
wrapper_descriptor          3782         3782           +0           +0
weekday                       14           14           +0           +0
weakref                    13226        13226           +0           +0
weakcallableproxy              1            1           +0           +0
vectorize                      4            4           +0           +0
validate_nseq_int              1            1           +0           +0
validate_nseq_float            5            5           +0           +0
uname_result                   1            1           +0           +0
tzutc                          1            1           +0           +0
tzUTC                          1            1           +0           +0
type                        6277         6277           +0           +0
staticmethod                1212         1212           +0           +0
slice                         72           72           +0           +0
set                         6633         6633           +0           +0
scputimes                    113          113           +0           +0
pybind11_type                 49           49           +0           +0
=======================================================================
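The thread doesn't name the tool that produced this table. A per-type growth report in the same spirit can be built from the standard library by snapshotting live GC-tracked objects before and after a batch of iterations. In this minimal sketch, `count_by_type`, `growth`, and the `Leaky` class are hypothetical names used only for illustration.

```python
import gc
from collections import Counter

def count_by_type():
    # Count live objects tracked by the garbage collector, keyed by type name
    return Counter(type(o).__name__ for o in gc.get_objects())

def growth(before, after, limit=10):
    # Per-type (name, current_count, delta), largest increase first; no-growth types dropped
    rows = [(name, after[name], after[name] - before[name]) for name in after]
    rows = [row for row in rows if row[2] > 0]
    rows.sort(key=lambda row: row[2], reverse=True)
    return rows[:limit]

class Leaky:
    pass

_registry = []  # stands in for a container that is never cleared

gc.collect()
before = count_by_type()
for _ in range(100):
    _registry.append(Leaky())  # simulate one object retained per iteration
gc.collect()
after = count_by_type()
for name, total, delta in growth(before, after):
    print(f"{name:<20} {total:>8} {delta:>+8}")
```

Types whose delta scales exactly with the iteration count are the likely leaks. In the table above, EagerFunc, Tape, GradientTape and _GeneratorState all grow by +100 per 100 iterations.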

[Image: reference graph of a leaking EagerFunc object]

So it seems the problem is that a new py_function is created on every loop iteration, and `_py_funcs_used_in_graph` keeps growing.
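If that diagnosis holds, one possible mitigation is to build the pipeline once and re-iterate it. That way `Dataset.map` wraps `tf.py_function` only once instead of on every step. The sketch below assumes TF 2.x eager execution and has not been checked against the leak itself. `build_dataset` is a hypothetical helper, and the repro's second `map(parse_data)` is omitted for brevity.

```python
import tensorflow as tf

def _generator():
    for _ in range(100):
        yield "1,2,3,4,5,6,7,8"

def _py_parse_data(record):
    # Inside tf.py_function, record is an eager tensor holding bytes
    values = record.numpy().decode().split(",")
    return ",".join(str(int(v) + 1) for v in values)

def build_dataset():
    ds = tf.data.Dataset.from_generator(_generator, output_types=tf.string)
    # A single dtype for Tout makes py_function return one tensor, not a list
    return ds.map(lambda r: tf.py_function(_py_parse_data, [r], tf.string))

# Build the pipeline (and therefore the py_function) once, outside the loop
dataset = build_dataset()

counts = []
for step in range(3):
    # Each pass creates a fresh iterator, which calls _generator() again
    counts.append(sum(1 for _ in dataset))
```

Whether memory then stays flat should be confirmed by tracking RSS across many steps, as in the original repro.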