tensorflow: tf.function with input_signature slower on unseen sequence length
System information
- Have I written custom code: Yes
- OS Platform and Distribution: Linux Ubuntu 16.04
- TensorFlow installed from: binary
- TensorFlow version: 2.0.0-dev20190526
- Python version: 3.6.6
- CUDA/cuDNN version: 10.0
- GPU model and memory: GTX 1060
Describe the current behavior
When running a tf.function on 3D inputs (a batch of sequences), I found that execution is slower on unseen sequence lengths, even when a compatible input_signature is set. On a large graph, this results in low GPU usage and growing CPU memory usage for many iterations, until most sequence lengths have been seen. It is as if new graphs were compiled internally, even though the function does not appear to be retraced.
This effect does not occur in eager mode or V1 graph mode, where execution directly runs at its target speed and memory usage.
Describe the expected behavior
tf.function with an input signature should behave like graph mode, with constant memory usage and no "warmup" phase.
Code to reproduce the issue
While this issue is very visible on large graphs, I tried to compile a small example to consistently show the effect:
```python
import time
import itertools
import random

import tensorflow as tf


def generate_token_based_shapes(num_tokens=4096):
    while True:
        length = random.randint(1, 100)
        batch_size = int(num_tokens / length)
        yield (batch_size, length)

# Generate 500k tensors of shape [None, None, 512] but with a similar total size.
shapes = list(itertools.islice(generate_token_based_shapes(), 500000))
dataset = tf.data.Dataset.from_tensor_slices(shapes)
dataset = dataset.shuffle(len(shapes))
dataset = dataset.map(lambda shape: tf.zeros(tf.concat([shape, [512]], axis=0)))
dataset = dataset.repeat()
dataset = dataset.prefetch(1)

# Define a model with some layers.
model = tf.keras.Sequential([
    tf.keras.layers.Dense(1024),
    tf.keras.layers.Dense(1024),
    tf.keras.layers.Dense(1024),
    tf.keras.layers.Dense(1024),
    tf.keras.layers.Dense(1024)])

@tf.function(input_signature=(tf.TensorSpec([None, None, 512], dtype=tf.float32),))
def run_step(inputs):
    return model(inputs)

seen_lengths = set()
for x in dataset:
    length = x.shape[1]
    start = time.time()
    _ = run_step(x)
    end = time.time()
    print(length in seen_lengths, end - start)
    seen_lengths.add(length)
```
Other info / logs
The above code produced the following logs when run on CPU:
```
False 0.43003296852111816
False 0.11496973037719727
False 0.11308979988098145
False 0.11620664596557617
False 0.11439895629882812
False 0.11322546005249023
True 0.095062255859375
False 0.11357808113098145
False 0.11438512802124023
False 0.11338496208190918
False 0.1123197078704834
False 0.11295366287231445
False 0.11250948905944824
False 0.11576318740844727
False 0.1139533519744873
False 0.11278915405273438
False 0.11090493202209473
True 0.09256935119628906
False 0.11287093162536621
False 0.11374545097351074
False 0.11446619033813477
False 0.11277508735656738
False 0.11354255676269531
False 0.11325383186340332
False 0.1137855052947998
False 0.11451315879821777
False 0.11423110961914062
True 0.09340834617614746
False 0.1146705150604248
False 0.11285781860351562
False 0.11371898651123047
True 0.09309053421020508
True 0.09239482879638672
True 0.09140896797180176
False 0.11467862129211426
False 0.11377716064453125
False 0.11178278923034668
False 0.11260485649108887
True 0.09450674057006836
True 0.09363818168640137
True 0.09272456169128418
False 0.11517977714538574
False 0.11325454711914062
True 0.09257698059082031
False 0.11360836029052734
True 0.09241485595703125
False 0.11343145370483398
True 0.09368515014648438
False 0.11366653442382812
True 0.09125065803527832
False 0.1126089096069336
False 0.11182904243469238
True 0.09548735618591309
True 0.09283709526062012
```
When the length is unseen, a step takes about 0.113s; once the length has been seen, it takes about 0.092s.
The effect is small in this example, but I'm trying to train a Transformer model with tf.function and it takes a very long time for the training to reach full speed. The CPU memory usage also keeps growing during this "warmup" phase. The same model works well when integrated with tf.estimator; I'm now trying to move from Estimator to V2 custom training loops.
About this issue
- Original URL
- State: closed
- Created 5 years ago
- Comments: 48 (25 by maintainers)
@georgesterpu Would it work for you to catch the exception outside the tf.function?
If this helps, I'm using this construct to iterate over datasets. The decorated function is turned into a callable that returns a generator over its outputs. The exception is used as the end condition.
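A minimal sketch of that construct, assuming a step function similar to the one in the issue (the model, shapes, and helper name here are illustrative, not the original code):

```python
import tensorflow as tf

model = tf.keras.layers.Dense(4)

@tf.function(input_signature=(tf.TensorSpec([None, 2], tf.float32),))
def run_step(inputs):
    return model(inputs)

def run_steps(dataset):
    """Return a generator over the step outputs for one epoch.

    The OutOfRangeError raised by get_next() on an exhausted iterator is
    caught here, outside the tf.function, and used as the end condition.
    """
    iterator = iter(dataset)
    while True:
        try:
            inputs = iterator.get_next()
        except tf.errors.OutOfRangeError:
            return
        yield run_step(inputs)

# Two batches of three examples -> two step outputs.
dataset = tf.data.Dataset.from_tensor_slices(tf.zeros([6, 2])).batch(3)
outputs = list(run_steps(dataset))
```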
+1 @guillaumekln, catching the exception outside the tf.function should work as well. That would simplify your code quite a bit, and it works as long as you don't need to wrap the outer loop in a function.
As for avoiding assumptions on the structure, or writing it by hand, we can craft a version of map_structure that works with None shapes:
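A sketch of that idea, assuming the goal is to build a zeros default value for an arbitrary element structure (the helper name is hypothetical): unknown dimensions are replaced by 0 so that tf.zeros can be called on each spec.

```python
import tensorflow as tf

def make_default(element_spec):
    """Build a zeros value matching any element structure.

    Hypothetical helper: unknown (None) dimensions in each TensorSpec are
    replaced by 0 so that tf.zeros accepts the resulting shape.
    """
    def zeros_for_spec(spec):
        shape = [0 if dim is None else dim for dim in spec.shape.as_list()]
        return tf.zeros(shape, dtype=spec.dtype)
    return tf.nest.map_structure(zeros_for_spec, element_spec)

# The leading batch dimension below is unknown (None) in the element spec.
dataset = tf.data.Dataset.from_tensor_slices(tf.zeros([4, 3, 5])).batch(2)
default = make_default(dataset.element_spec)
```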
Ah, I suspect the error comes from trying to call tf.zeros with a partially-known shape, in which case my nest trick doesn't work 😦. Instead you'll need to craft the structure by hand. The important part is to use a structure and shapes consistent with those created by your dataset. Based on the error message, something like this should work:
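For instance, assuming the dataset yields pairs of a float32 feature tensor of shape [None, None, 512] and an int32 length vector (this structure is a guess, not the one from the original error message), a hand-built default only needs matching structure and dtypes; the sizes used for the unknown dimensions, and the zeros themselves, are arbitrary:

```python
import tensorflow as tf

# Hand-crafted default value; the structure and dtypes must match the
# dataset elements, the sizes for the unknown dimensions are arbitrary.
default = (tf.zeros([1, 1, 512], tf.float32),
           tf.zeros([1], tf.int32))
```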
As for signaling the end of an epoch, I recommend setting up a separate return value. As long as you avoid using None, things should work. Let me know if that works.
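A sketch of such a separate return value, assuming a simple batched dataset (the shapes and the zeros default are illustrative): the function always returns a (finished, value) pair of tensors, never None.

```python
import tensorflow as tf

@tf.function
def next_or_done(iterator):
    """Return a (finished, value) pair instead of ever returning None."""
    optional = iterator.get_next_as_optional()
    finished = tf.logical_not(optional.has_value())
    value = tf.cond(
        optional.has_value(),
        optional.get_value,
        lambda: tf.zeros([0, 2], tf.float32))  # arbitrary default value
    return finished, value

dataset = tf.data.Dataset.from_tensors(tf.ones([3, 2]))
iterator = iter(dataset)
finished1, value1 = next_or_done(iterator)  # real element
finished2, value2 = next_or_done(iterator)  # end of dataset
```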
Here's a snippet that shows the use of Optional. Note the tf.nest trick to handle any structure your dataset may have. The choice of zeros for the default value is entirely arbitrary:
With autograph:
With tf.cond:
The two examples are equivalent.
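A sketch of what the two forms may have looked like, under assumed shapes: a plain `if` on the Optional, which AutoGraph converts into a tf.cond, versus writing the tf.cond explicitly. Both trace to the same graph.

```python
import tensorflow as tf

dataset = tf.data.Dataset.from_tensors(tf.ones([2, 3]))

# tf.nest trick: build a zeros default matching any element structure.
def default_value():
    return tf.nest.map_structure(
        lambda spec: tf.zeros(spec.shape, spec.dtype), dataset.element_spec)

@tf.function
def next_with_autograph(iterator):
    optional = iterator.get_next_as_optional()
    if optional.has_value():  # AutoGraph turns this branch into a tf.cond
        value = optional.get_value()
    else:
        value = default_value()
    return optional.has_value(), value

@tf.function
def next_with_cond(iterator):
    optional = iterator.get_next_as_optional()
    return optional.has_value(), tf.cond(
        optional.has_value(), optional.get_value, default_value)

has_value, value = next_with_cond(iter(dataset))
```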
@georgesterpu The result of get_next_as_optional is an Optional, which means you will have to use tf.cond to check whether it contains a value (and if so, extract the value using get_value) before indexing it.

@georgesterpu TF graphs don't support exception handling, so you need to use alternative methods. I recommend either get_next_as_optional or rephrasing your loop using dataset.reduce and friends, like we do internally in autograph: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/autograph/operators/control_flow.py#L584
CC @jsimsa and @mdanatg
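A minimal sketch of the dataset.reduce approach, which keeps the whole loop inside the graph so no exception handling is needed; here the step logic is reduced to a counter threaded through as the state.

```python
import tensorflow as tf

dataset = tf.data.Dataset.range(5)

@tf.function
def count_steps():
    # The loop over the dataset runs entirely inside the graph; reduce
    # threads the state (here just the number of elements seen) through
    # each call of the reduce function.
    return dataset.reduce(
        tf.constant(0, tf.int64),
        lambda steps, element: steps + 1)

total = count_steps()
```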
Sorry to bump this issue again, but will this get fixed for the final TensorFlow 2.0 release? If not, maybe the documentation should clarify that there is a second level of cache in the TensorFlow runtime, in addition to the tf.function input signatures.
While this issue can be mitigated when using datasets as shown above, I also found that functions loaded from a V2 SavedModel are barely usable because of it.