tensorflow: Calling tf.function from tf.py_function in dataset.map hangs.
System information
- Have I written custom code (as opposed to using a stock example script provided in TensorFlow): yes.
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04): macOS or Windows.
- TensorFlow installed from (source or binary): binary.
- TensorFlow version: 2.0.0b1 / 2.0.0rc0 / 2.0.0rc1
- Python version: 3.6.
- CUDA/cuDNN version: 10.0
- GPU model and memory: Tesla K80
Describe the current behavior
Calling a tf.function from tf.py_function inside dataset.map hangs the program. Removing the tf.function decorator or enabling run_functions_eagerly makes the program run as expected.
Describe the expected behavior
Calling a tf.function from tf.py_function does not hang the program.
Code to reproduce the issue
```python
import tensorflow as tf


@tf.function
def generate_feature(key):
    # Magnitudes in [0, 20) for non-negative keys, [40, 100) otherwise.
    if key > tf.constant(-1):
        x = tf.random.uniform(shape=(), minval=0.0, maxval=20.0)
        y = tf.random.uniform(shape=(), minval=0.0, maxval=20.0)
    else:
        x = tf.random.uniform(shape=(), minval=40.0, maxval=100.0)
        y = tf.random.uniform(shape=(), minval=40.0, maxval=100.0)
    # Give each coordinate a random sign.
    x = tf.math.sign(tf.random.uniform(shape=(), minval=-1.0, maxval=1.0)) * tf.abs(x)
    y = tf.math.sign(tf.random.uniform(shape=(), minval=-1.0, maxval=1.0)) * tf.abs(y)
    return tf.stack([x, y])


def generate_feature_and_label(key):
    # Runs eagerly inside tf.py_function; calling the tf.function here is what hangs.
    feature = generate_feature(key)
    if key > -1:
        label = tf.constant(1)
    else:
        label = tf.constant(0)
    return feature, label


def dataset_map(key):
    feature, label = tf.py_function(func=generate_feature_and_label,
                                    inp=[key],
                                    Tout=[tf.float32, tf.int32])
    # py_function loses static shapes, so restore them.
    feature = tf.ensure_shape(feature, [2])
    label = tf.ensure_shape(label, [])
    return feature, label


if __name__ == '__main__':
    print(tf.__version__)
    # remove tf.function decorator or enable run_functions_eagerly to run successfully
    # tf.config.experimental_run_functions_eagerly(True)
    keys = list(range(-1000, 1000))
    dataset = tf.data.Dataset.from_tensor_slices(keys)
    dataset = dataset.repeat()
    dataset = dataset.shuffle(buffer_size=len(keys))
    dataset = dataset.map(map_func=dataset_map,
                          num_parallel_calls=tf.data.experimental.AUTOTUNE)
    dataset = dataset.batch(4)
    it = iter(dataset)
    while True:
        features, labels = next(it)
        print(f'labels={labels}, features={features}')
        input('press any key to continue...\r\n')
```
About this issue
- State: closed
- Created 5 years ago
- Comments: 18 (7 by maintainers)
+1 @mdanatg. @makercob, in your case (map_fn -> py_func -> tf.function), TF actually creates two graphs: one for the map function and one for the tf.function. This adds extra overhead; for example, ops in the two graphs cannot be executed in parallel. @mdanatg's suggestion can maximize your performance gain.
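As a minimal sketch of keeping everything in one graph, assuming the sampling logic from the repro can be expressed entirely with TF ops, the map function below drops both tf.py_function and the nested tf.function. This rewrite of generate_feature_and_label is hypothetical and is not necessarily the suggestion from @mdanatg referenced above; it draws each magnitude from [low, high) directly instead of branching, which gives the same distribution as the repro.

```python
import tensorflow as tf


def generate_feature_and_label(key):
    # Hypothetical pure-TF rewrite: no py_function, no nested tf.function.
    positive = key > -1                      # scalar bool tensor
    low = tf.where(positive, 0.0, 40.0)      # lower bound of |x| and |y|
    high = tf.where(positive, 20.0, 100.0)   # upper bound of |x| and |y|
    # Uniform in [low, high) for both coordinates at once.
    magnitude = tf.random.uniform(shape=(2,)) * (high - low) + low
    # Random sign per coordinate, as in the repro.
    sign = tf.math.sign(tf.random.uniform(shape=(2,), minval=-1.0, maxval=1.0))
    feature = sign * magnitude
    label = tf.cast(positive, tf.int32)      # 1 for key > -1, else 0
    return feature, label


keys = list(range(-1000, 1000))
dataset = (tf.data.Dataset.from_tensor_slices(keys)
           .repeat()
           .shuffle(buffer_size=len(keys))
           .map(generate_feature_and_label,
                num_parallel_calls=tf.data.experimental.AUTOTUNE)
           .batch(4))
```

Because the map function is traced as a single graph, static shapes are known and tf.ensure_shape is no longer needed.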
That said, we are fixing the hang by no longer executing py_func on threads from the inter-op thread pool. We will update this thread when the fix lands.
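The hang can be pictured as thread-pool starvation: a task running on a pool thread blocks while waiting for more work scheduled on the same pool. The plain-Python sketch below is only an analogy inferred from the description of the fix; it does not model TensorFlow's executor, and the hypothetical `concurrent.futures` pool merely stands in for the inter-op thread pool.

```python
import concurrent.futures as cf


def outer(pool):
    # Plays the role of py_func: occupies a pool worker, then waits on
    # work (the "inner tf.function") scheduled on the SAME pool.
    inner = pool.submit(lambda: 42)
    # The timeout exists only so this demo terminates instead of hanging.
    return inner.result(timeout=1.0)


def run(max_workers):
    with cf.ThreadPoolExecutor(max_workers=max_workers) as pool:
        try:
            return pool.submit(outer, pool).result()
        except cf.TimeoutError:
            # The inner task never got a free worker while outer waited.
            return "starved"
```

With one worker, `run(1)` returns `"starved"` because the inner task cannot start until `outer` releases the only thread; `run(2)` returns `42` because a second worker is free to run it. Running py_func on a different set of threads removes the dependency on the shared pool.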
Sorry about the delay. TF runtime folks are looking into this bug. Will post updates here.