tensorflow: Slowdown of `tf.scan` operation in TF 2.5


System information

  • Have I written custom code (as opposed to using a stock example script provided in TensorFlow): yes
  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): MacOS Mojave Version 10.14.5
  • Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if the issue happens on mobile device: -
  • TensorFlow installed from (source or binary): binary
  • TensorFlow version (use command below): v2.5.0-rc3-213-ga4dfb8d1a71 2.5.0
  • Python version: 3.7.5
  • Bazel version (if compiling from source):
  • GCC/Compiler version (if compiling from source):
  • CUDA/cuDNN version: -
  • GPU model and memory: -

Describe the current behavior

There appears to be a considerable slowdown whenever a `tf.scan` operation is used inside a model's loss function and a learnable parameter of the model (`self.transition_param` below) takes part in the `tf.scan` op. The following snippet reproduces it:

```python
import numpy as np
import tensorflow as tf
from tqdm import tqdm


def gen_batches(num_batches, batch_size, units):
    # Random per-step scores of shape (batch, time=20, units) and random tag ids.
    for _ in range(num_batches):
        x = np.random.random((batch_size, 20, units))
        y = np.random.randint(1, units, size=(batch_size, 20))
        yield x, y


class MyModel(tf.keras.models.Model):

    def __init__(self, units):
        super().__init__()

        self._tf_layers = {}

        self.units = units

        # Learnable (units x units) transition matrix; it is used inside tf.scan below.
        self.transition_param = self.add_weight(name="transition_param", shape=(units, units))

        self.optimizer = tf.keras.optimizers.Adam()
        self._training = False

    def _loss_fn_with_scan(self, inputs, transition_params):
        # inputs: (batch, time, units). The first time step becomes the initial scan state.
        first_input = tf.slice(inputs, [0, 0, 0], [-1, 1, -1])
        first_input = tf.squeeze(first_input, [1])

        rest_of_input = tf.slice(inputs, [0, 1, 0], [-1, -1, -1])

        # tf.scan iterates over the leading axis, so move time to axis 0.
        rest_of_input = tf.transpose(rest_of_input, [1, 0, 2])
        transition_params = tf.expand_dims(transition_params, 0)

        def _scan_fn(_state, _inputs):
            # (batch, units) -> (batch, units, 1), broadcast against (1, units, units).
            _state = tf.expand_dims(_state, 2)
            transition_scores = _state + transition_params
            # Log-sum-exp over the previous-tag axis gives the new forward scores.
            new_alphas = _inputs + tf.reduce_logsumexp(transition_scores, [1])
            return new_alphas

        all_alphas = tf.transpose(tf.scan(_scan_fn, rest_of_input, first_input), [1, 0, 2])
        # add first state for sequences of length 1
        all_alphas = tf.concat([tf.expand_dims(first_input, 1), all_alphas], 1)

        return all_alphas

    def _loss(self, x, y):
        # y is not used; the repro only needs the scan over x.
        logits = tf.cast(x, dtype=tf.float32)

        loss = self._loss_fn_with_scan(logits, self.transition_param)

        return tf.reduce_mean(loss)

    @tf.function
    def train_on_batch(self, *args):
        with tf.GradientTape(persistent=True) as tape:
            loss = self._loss(*args)
        grads = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
        return loss

    def train(self, epochs, batch_size, num_batches):

        data_generator_iter = gen_batches(num_batches, batch_size, self.units)

        # The first call traces the tf.function before the timed loop starts.
        sample_x, sample_y = next(data_generator_iter)

        self.train_on_batch(sample_x, sample_y)

        self._training = True

        progress_bar = tqdm(range(epochs), desc="Epochs")

        for epoch in progress_bar:
            # Note: data_generator_iter is a one-shot generator, so it is exhausted
            # after the first epoch; later epochs run no training steps.
            for batch_x, batch_y in data_generator_iter:
                loss = self.train_on_batch(batch_x, batch_y)

            # Iterating over the tqdm object already advances the bar once per epoch.
            progress_bar.set_postfix({"loss": f"{loss.numpy():.3f}"})


num_batches = 5000
batch_size = 32
units = 64
epochs = 100

model = MyModel(units)
model.train(epochs, batch_size, num_batches)
```

The same code takes about 3 min 16 s on TF 2.3 but about 4 min 1 s on TF 2.5, roughly 23% slower.
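When comparing versions, it may help to report the first call, which includes `tf.function` tracing, separately from steady-state step time, so the two costs can be compared across TF 2.3 and TF 2.5 independently. Below is a minimal stdlib-only sketch; `time_steps` is a hypothetical helper, not a TensorFlow API, and it assumes the step callable blocks until its work is finished (for example by calling `.numpy()` on the returned loss, so asynchronous execution does not hide work).

```python
import statistics
import time
from typing import Any, Callable, Dict, List


def time_steps(step: Callable[[], Any], warmup: int = 1, repeats: int = 20) -> Dict[str, float]:
    """Hypothetical timing helper: run `step` `warmup` times, then `repeats` timed times.

    The warm-up phase is timed as a whole so that one-off costs such as
    tf.function tracing are reported separately from steady-state step time.
    """
    if repeats < 1:
        raise ValueError("repeats must be >= 1")

    # Warm-up: total wall-clock time, including any tracing on the first call.
    start = time.perf_counter()
    for _ in range(warmup):
        step()
    warmup_s = time.perf_counter() - start

    # Steady state: time each call individually.
    samples: List[float] = []
    for _ in range(repeats):
        t0 = time.perf_counter()
        step()
        samples.append(time.perf_counter() - t0)

    return {
        "warmup_s": warmup_s,
        "median_step_s": statistics.median(samples),
        "min_step_s": min(samples),
    }
```

With the repro's `MyModel`, a call might look like `time_steps(lambda: model.train_on_batch(x, y).numpy())`, where `x, y` is one batch from `gen_batches`; running the same call under each TensorFlow version gives warm-up and per-step figures that can be compared side by side.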

The above code is a stripped-down version of a model architecture that uses transformers plus a CRF layer. `_loss_fn_with_scan` is taken from the `crf_log_norm` function of the tensorflow-addons package, but as the snippet shows, the problem is already visible without the CRF-specific code, using only `tf.scan` inside the loss function (`tf.scan` is part of the TensorFlow repo, hence this issue is filed here). On large datasets the slowdown is much larger: training took 1 hour 20 minutes with TF 2.3 and close to 3 hours with TF 2.5, roughly twice as long. This is a big blocker for us upgrading to TF 2.5.
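For checking results independently of TensorFlow, the recurrence that `_scan_fn` implements can be written out in plain Python for a single sequence: `alpha_t[j] = x_t[j] + logsumexp_i(alpha_{t-1}[i] + T[i][j])`, starting from `alpha_0 = x_0`. The sketch below is a reference only; `forward_alphas` and `log_norm` are hypothetical names, not the tensorflow-addons API. `log_norm` applies a final log-sum-exp over tags, which, as far as we understand, is how `crf_log_norm` turns the last forward scores into the log-partition value.

```python
import math
from typing import List, Sequence


def logsumexp(values: Sequence[float]) -> float:
    """Numerically stable log(sum(exp(v))) over a non-empty sequence."""
    m = max(values)
    if m == -math.inf:
        return -math.inf
    return m + math.log(sum(math.exp(v - m) for v in values))


def forward_alphas(inputs: Sequence[Sequence[float]],
                   transitions: Sequence[Sequence[float]]) -> List[List[float]]:
    """Forward scores for one sequence (hypothetical reference implementation).

    inputs: T x N per-step scores; transitions: N x N, where transitions[i][j]
    scores moving from tag i to tag j. Returns T x N alphas; row 0 equals
    inputs[0], like `first_input` in the repro.
    """
    num_tags = len(transitions)
    alphas = [list(inputs[0])]
    for x_t in inputs[1:]:
        prev = alphas[-1]
        # Mirrors _scan_fn: expand_dims(state, 2) + transitions, then a
        # log-sum-exp over axis 1 (the previous tag).
        alphas.append([
            x_t[j] + logsumexp([prev[i] + transitions[i][j] for i in range(num_tags)])
            for j in range(num_tags)
        ])
    return alphas


def log_norm(inputs: Sequence[Sequence[float]],
             transitions: Sequence[Sequence[float]]) -> float:
    """Log of the summed exp-scores over all tag paths (log-partition)."""
    return logsumexp(forward_alphas(inputs, transitions)[-1])
```

Comparing `forward_alphas` against one batch row of `_loss_fn_with_scan` output on a small input is a quick way to confirm both TF versions compute the same values before comparing their speed.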

Describe the expected behavior

Training times should be comparable across TF 2.3 and TF 2.5.

Contributing

  • Do you want to contribute a PR? (yes/no): maybe
  • Briefly describe your candidate solution (if contributing): I am not familiar with the internals of the `tf.scan` op. With some guidance, we could give it a try.

Standalone code to reproduce the issue

The snippet above reproduces the issue.

About this issue

  • State: closed
  • Created 3 years ago
  • Comments: 42 (18 by maintainers)

Most upvoted comments

Hi, all! Just wanted to let you know we’re investigating – no conclusions yet, but I have a couple leads I’m following. I’ll post back when I have more. Thanks for the info so far!

Greetings! Sure, let me take another look; I wasn't entirely sure whether I was awaiting GPU data. (I also admit I'm making heavy use of other folks' expertise on this one, and alas some of those people are now on different teams over the holidays.) The timing is good though: I'm just wrapping up one project, so I have a little time to see what we can do here.