tensorflow: Slowdown of `tf.scan` operation in TF 2.5
Please make sure that this is a bug. As per our GitHub Policy, we only address code/doc bugs, performance issues, feature requests and build/installation issues on GitHub.
System information
- Have I written custom code (as opposed to using a stock example script provided in TensorFlow): yes
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04): macOS Mojave 10.14.5
- Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if the issue happens on mobile device: -
- TensorFlow installed from (source or binary): binary
- TensorFlow version (use command below): v2.5.0-rc3-213-ga4dfb8d1a71 2.5.0
- Python version: 3.7.5
- Bazel version (if compiling from source):
- GCC/Compiler version (if compiling from source):
- CUDA/cuDNN version: -
- GPU model and memory: -
Describe the current behavior
There seems to be a considerable slowdown whenever a `tf.scan` operation is used inside the loss function of a model and a learnable parameter of the model (`self.transition_param` below) is part of the `tf.scan` op. Here's a code snippet to reproduce this:
```python
import tensorflow as tf
import tensorflow_addons as tfa
from tqdm import tqdm
import numpy as np


def gen_batches(num_batches, batch_size, units):
    for _ in range(num_batches):
        x = np.random.random((batch_size, 20, units))
        y = np.random.randint(1, units, size=(batch_size, 20))
        yield x, y


class MyModel(tf.keras.models.Model):
    def __init__(self, units):
        super().__init__()
        self._tf_layers = {}
        self.units = units
        self.transition_param = self.add_weight(name="transition_param", shape=(units, units))
        self.optimizer = tf.keras.optimizers.Adam()
        self._training = False

    def _loss_fn_with_scan(self, inputs, transition_params):
        first_input = tf.slice(inputs, [0, 0, 0], [-1, 1, -1])
        first_input = tf.squeeze(first_input, [1])
        rest_of_input = tf.slice(inputs, [0, 1, 0], [-1, -1, -1])
        rest_of_input = tf.transpose(rest_of_input, [1, 0, 2])
        transition_params = tf.expand_dims(transition_params, 0)

        def _scan_fn(_state, _inputs):
            _state = tf.expand_dims(_state, 2)
            transition_scores = _state + transition_params
            new_alphas = _inputs + tf.reduce_logsumexp(transition_scores, [1])
            return new_alphas

        all_alphas = tf.transpose(tf.scan(_scan_fn, rest_of_input, first_input), [1, 0, 2])
        # add first state for sequences of length 1
        all_alphas = tf.concat([tf.expand_dims(first_input, 1), all_alphas], 1)
        return all_alphas

    def _loss(self, x, y):
        logits = tf.cast(x, dtype=tf.float32)
        loss = self._loss_fn_with_scan(logits, self.transition_param)
        return tf.reduce_mean(loss)

    @tf.function
    def train_on_batch(self, *args):
        with tf.GradientTape(persistent=True) as tape:
            loss = self._loss(*args)
        grads = tape.gradient(loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
        return loss

    def train(self, epochs, batch_size, num_batches):
        data_generator_iter = gen_batches(num_batches, batch_size, self.units)
        sample_x, sample_y = next(data_generator_iter)
        self.train_on_batch(sample_x, sample_y)
        self._training = True
        progress_bar = tqdm(range(epochs), desc="Epochs")
        for epoch in progress_bar:
            for batch_x, batch_y in data_generator_iter:
                loss = self.train_on_batch(batch_x, batch_y)
                progress_bar.update(1)
                progress_bar.set_postfix({"loss": f"{loss.numpy():.3f}"})


num_batches = 5000
batch_size = 32
units = 64
epochs = 100

model = MyModel(units)
model.train(epochs, batch_size, num_batches)
```
The same code takes ~3 minutes 16 seconds on TF 2.3, but ~4 minutes 1 second on TF 2.5.
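To separate the regression from the Keras training loop, a rough micro-benchmark like the one below (shapes and iteration count are just illustrative) times only the forward and backward pass of the scan-based loss inside a `tf.function`; running it under both TF versions should show whether `tf.scan` and its gradient alone account for the difference.

```python
import time

import numpy as np
import tensorflow as tf

units, batch_size, steps = 64, 32, 20
transition_param = tf.Variable(tf.random.normal((units, units)))
x = tf.constant(np.random.random((batch_size, steps, units)), dtype=tf.float32)


@tf.function
def scan_loss_and_grad(inputs):
    with tf.GradientTape() as tape:
        first = tf.squeeze(tf.slice(inputs, [0, 0, 0], [-1, 1, -1]), [1])
        rest = tf.transpose(tf.slice(inputs, [0, 1, 0], [-1, -1, -1]), [1, 0, 2])
        params = tf.expand_dims(transition_param, 0)

        def _scan_fn(state, inp):
            # same forward recursion as _loss_fn_with_scan above
            scores = tf.expand_dims(state, 2) + params
            return inp + tf.reduce_logsumexp(scores, [1])

        alphas = tf.scan(_scan_fn, rest, first)
        loss = tf.reduce_mean(alphas)
    return loss, tape.gradient(loss, transition_param)


scan_loss_and_grad(x)  # trace once so the timing below excludes graph building
start = time.perf_counter()
for _ in range(100):
    scan_loss_and_grad(x)
print(f"100 loss+gradient evaluations: {time.perf_counter() - start:.2f}s")
```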
The above code is a stripped-down version of a model architecture that uses transformers + a CRF layer. `_loss_fn_with_scan` is taken from the `crf_log_norm` function of the tensorflow-addons package, but as the snippet shows, the problem is already visible without the CRF bits, just by using `tf.scan` as part of the loss function (`tf.scan` lives in the tensorflow repo, hence the issue seems appropriate here). On large datasets the slowdown is substantial: a training run that took 1 hour 20 minutes on TF 2.3 takes close to 3 hours on TF 2.5, which is a big blocker for us to upgrade to TF 2.5.
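The same recursion can also be exercised through `tfa.text.crf_log_norm` directly; a rough sketch like the following (random inputs and fixed sequence lengths, purely for illustration) can be used to check whether the regression shows up through the full CRF path as well:

```python
import numpy as np
import tensorflow as tf
import tensorflow_addons as tfa

units, batch_size, steps = 64, 32, 20
logits = tf.constant(np.random.random((batch_size, steps, units)), dtype=tf.float32)
sequence_lengths = tf.fill([batch_size], steps)
transition_param = tf.Variable(tf.random.normal((units, units)))

with tf.GradientTape() as tape:
    # crf_log_norm runs the scan-based forward recursion that _loss_fn_with_scan was taken from
    log_norm = tfa.text.crf_log_norm(logits, sequence_lengths, transition_param)
    loss = tf.reduce_mean(log_norm)
grads = tape.gradient(loss, transition_param)
```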
Describe the expected behavior
Training times should be comparable between TF 2.3 and TF 2.5.
- Do you want to contribute a PR? (yes/no): maybe
- Briefly describe your candidate solution (if contributing): I am not familiar with the internals of the `tf.scan` op. If some help is available, we could give it a try.
Standalone code to reproduce the issue: a snippet is provided above.
About this issue
- Original URL
- State: closed
- Created 3 years ago
- Comments: 42 (18 by maintainers)
Hi, all! Just wanted to let you know we’re investigating – no conclusions yet, but I have a couple leads I’m following. I’ll post back when I have more. Thanks for the info so far!
Greetings! Sure, let me take another look – I wasn’t entirely sure whether I was awaiting GPU data. (I also admit I’m making heavy use of other folks’ expertise on this one, and alas some of those people are now on different teams over the holidays.) The timing is good, though; I’m just wrapping up one project, so I have a little time to see what we can do here.