tensorflow: TensorFlow 2.0 training OOM when tf.function retraces excessively
When I train my model using tf.function, CPU memory leaks and after some training steps the process gets killed.
System information
- OS Platform: Linux Ubuntu 16.04
- TensorFlow installed from (source or binary): binary
- TensorFlow version: 2.0-beta
- Python version: 3.7
- Bazel version (if compiling from source):
- GCC/Compiler version (if compiling from source):
- CUDA/cuDNN version: 10.01
- GPU model and memory: P100, 16 GB
@tf.function
def train_step(self, x, step, grad_clip=True, clip_value=1.0):
    with tf.name_scope("input_data"):
        inputs = x[:, :-1]
        targets = x[:, 1:]
    with tf.GradientTape() as tape:
        predictions, _ = model(inputs, training=True)
        loss = self.get_loss(targets, predictions)
    with tf.name_scope("gradients"):
        gradients = tape.gradient(loss, self.trainable_variables)
        if grad_clip:
            gradients = [tf.clip_by_value(grad, -clip_value, clip_value)
                         for grad in gradients]
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
    accuracy = self.get_padded_accuracy(targets, predictions)
    assert self.train_writer is not None
    with tf.name_scope("summary_writer"):
        with self.train_writer.as_default():
            tf.summary.scalar("loss", loss, step=step)
            tf.summary.scalar("accuracy", accuracy, step=step)
    return loss, accuracy
**Stats of CPU memory usage:**
- After 100 steps: 5.08 GB
- After 500 steps: 14.8 GB
- After 1000 steps: 27 GB
When training in graph mode with tf.function, CPU memory leaks and the process is killed after some steps, but when I train the same model in eager mode it works fine. I am using the code above for graph-mode training, as recommended by the TensorFlow 2.0 community.
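For reference, here is a minimal sketch (a hypothetical toy function, not the original train_step) that shows how retracing can be observed: the Python-level print() runs only while tf.function traces a new graph, so repeated output signals that new graphs are being built and cached.

```python
import tensorflow as tf

# Minimal sketch: print() executes only during tracing, not on every call,
# so each new input shape that triggers a trace produces another line.
@tf.function
def square(x):
    print("Tracing for input shape:", x.shape)
    return x * x

square(tf.ones([8, 10]))   # traces a graph for shape (8, 10)
square(tf.ones([8, 10]))   # same shape -> cached graph reused, no print
square(tf.ones([8, 12]))   # new sequence length -> a second graph is traced
```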
About this issue
- Original URL
- State: closed
- Created 5 years ago
- Reactions: 1
- Comments: 23 (8 by maintainers)
That’s the reason why! Because tf.function re-traces the graph for every sequence length or batch size it has not seen before, the number of cached graphs keeps growing over time, and that is what causes the OOM. To avoid re-tracing, you might want to specify the input_signature. For a detailed explanation, please refer to
https://www.tensorflow.org/beta/tutorials/eager/tf_function
https://www.tensorflow.org/beta/tutorials/text/transformer (search for tf.function)
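As a rough illustration of the suggestion above (the dtypes, shapes, and the placeholder body are assumptions, not taken from the original code), an input_signature with None dimensions makes tf.function trace a single graph that accepts any batch size and sequence length. Note that Python-valued arguments such as grad_clip=True also key the trace cache, so they should be fixed or moved outside the traced function.

```python
import tensorflow as tf

# Sketch only: int32 token ids and an int64 step are assumed, and the body
# is a placeholder standing in for the real training logic.
@tf.function(input_signature=[
    tf.TensorSpec(shape=[None, None], dtype=tf.int32),  # x: any batch size / seq_len
    tf.TensorSpec(shape=[], dtype=tf.int64),            # step: summary step counter
])
def train_step(x, step):
    return tf.reduce_sum(tf.cast(x, tf.float32)) + tf.cast(step, tf.float32)

# Both calls reuse the single traced graph, so the graph cache no longer grows:
train_step(tf.ones([8, 10], dtype=tf.int32), tf.constant(1, dtype=tf.int64))
train_step(tf.ones([16, 37], dtype=tf.int32), tf.constant(2, dtype=tf.int64))
```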