tensorflow: tf.saved_model.save very slow with second-order tf.autodiff.ForwardAccumulator

System information

  • Have I written custom code (as opposed to using a stock example script provided in TensorFlow): yes
  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Docker Hub container ‘latest’, digest: 7bc36fe0ca1a051a808122e87f5438614b371263515df4794abef9a78440af8b
  • GPU model and memory: No GPU; CPU: 4x Intel® Core™ i7-8650U @ 1.90GHz, 32 GB RAM

Describe the current behavior

Saving a tf.Module that uses second-order tf.autodiff.ForwardAccumulator takes an excessive amount of time: about 1 hour for the example below.

Describe the expected behavior

Saving the graph in the example should take a few seconds.

Standalone code to reproduce the issue

import os
import tensorflow as tf
import time

class Issue_fwd(tf.Module):

    @tf.function(input_signature=[tf.TensorSpec([None, 1], tf.float64)] * 3 +
                                 [tf.TensorSpec([None, 3], tf.float64)] +
                                 [tf.TensorSpec([1, None], tf.float64)] * 4)
    def f(self, x1, x2, x3, c, v1, v2, v3, v4):

        with tf.autodiff.ForwardAccumulator(x1, tf.ones_like(x1)) as fwd_acc_x1_2, \
                tf.autodiff.ForwardAccumulator(x2, tf.ones_like(x2)) as fwd_acc_x2_2:

            with tf.autodiff.ForwardAccumulator(x1, tf.ones_like(x1)) as fwd_acc_x1, \
                 tf.autodiff.ForwardAccumulator(x2, tf.ones_like(x2)) as fwd_acc_x2, \
                 tf.autodiff.ForwardAccumulator(x3, tf.ones_like(x3)) as fwd_acc_x3:

                p = tf.concat([x1, x2, x3], axis=1)
                pe = tf.transpose(a=p[:, :, None], perm=[0, 2, 1])
                ce = tf.transpose(a=c[:, :, None], perm=[2, 0, 1])
                r = tf.reduce_sum(input_tensor=tf.square(ce - pe), axis=2)
                G = tf.exp(-r / 2)

                p = tf.reduce_sum(input_tensor=G * v1, axis=1, keepdims=True)
                b = tf.reduce_sum(input_tensor=G * v2, axis=1, keepdims=True)
                u = tf.reduce_sum(input_tensor=G * v3, axis=1, keepdims=True)
                w = tf.reduce_sum(input_tensor=G * v4, axis=1, keepdims=True)

            dpdx = fwd_acc_x1.jvp(p, unconnected_gradients=tf.UnconnectedGradients.ZERO)
            dbdx = fwd_acc_x1.jvp(b, unconnected_gradients=tf.UnconnectedGradients.ZERO)
            dudx = fwd_acc_x1.jvp(u, unconnected_gradients=tf.UnconnectedGradients.ZERO)
            dwdx = fwd_acc_x1.jvp(w, unconnected_gradients=tf.UnconnectedGradients.ZERO)

            dpdz = fwd_acc_x2.jvp(p, unconnected_gradients=tf.UnconnectedGradients.ZERO)
            dbdz = fwd_acc_x2.jvp(b, unconnected_gradients=tf.UnconnectedGradients.ZERO)
            dudz = fwd_acc_x2.jvp(u, unconnected_gradients=tf.UnconnectedGradients.ZERO)
            dwdz = fwd_acc_x2.jvp(w, unconnected_gradients=tf.UnconnectedGradients.ZERO)

            dbdt = fwd_acc_x3.jvp(b, unconnected_gradients=tf.UnconnectedGradients.ZERO)
            dudt = fwd_acc_x3.jvp(u, unconnected_gradients=tf.UnconnectedGradients.ZERO)
            dwdt = fwd_acc_x3.jvp(w, unconnected_gradients=tf.UnconnectedGradients.ZERO)

        d2ud2x = fwd_acc_x1_2.jvp(dudx, unconnected_gradients=tf.UnconnectedGradients.ZERO)
        d2ud2z = fwd_acc_x2_2.jvp(dudz, unconnected_gradients=tf.UnconnectedGradients.ZERO)

        d2wd2x = fwd_acc_x1_2.jvp(dwdx, unconnected_gradients=tf.UnconnectedGradients.ZERO)
        d2wd2z = fwd_acc_x2_2.jvp(dwdz, unconnected_gradients=tf.UnconnectedGradients.ZERO)

        d2bd2x = fwd_acc_x1_2.jvp(dbdx, unconnected_gradients=tf.UnconnectedGradients.ZERO)
        d2bd2z = fwd_acc_x2_2.jvp(dbdz, unconnected_gradients=tf.UnconnectedGradients.ZERO)

        return (dudx, dudz, dudt, dwdx, dwdz, dwdt, dbdx, dbdz, dbdt, dpdx, dpdz,
                d2ud2x, d2ud2z, d2wd2x, d2wd2z, d2bd2x, d2bd2z)


f = Issue_fwd()
saving_path = 'save_path'
os.makedirs(saving_path, exist_ok=True)

start_time = time.perf_counter()  # wall-clock timer; time.clock() was removed in Python 3.8
tf.saved_model.save(f, saving_path)
delta_time = time.perf_counter() - start_time
print('saving took {:f} seconds'.format(delta_time))
print('tf.version.GIT_VERSION={}'.format(tf.version.GIT_VERSION))
print('tf.version.VERSION={}'.format(tf.version.VERSION))

Other info / logs

saving took 3942.697280 seconds
tf.version.GIT_VERSION=v2.3.0-rc2-23-gb36436b087
tf.version.VERSION=2.3.0

About this issue

  • State: closed
  • Created 4 years ago
  • Comments: 27 (15 by maintainers)

Most upvoted comments

Ah, batching multiple tangents associated with one primal? A GSoC student started on built-in tf.vectorized_map integration, but unfortunately it’s not complete (e.g. it won’t work for higher-order derivatives yet because its tf.function integration has issues).

You can use it with tf.vectorized_map yourself for sure: https://github.com/tensorflow/tensorflow/blob/3da9cc88d244c415a6ac1c0b81fb0983809829b0/tensorflow/python/eager/forwardprop_test.py#L93-L97

One caveat is that tf.vectorized_map currently traces the function rather than executing it eagerly.
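A minimal sketch of that approach, assuming TensorFlow 2.x with eager execution enabled. `jvp` and `batched_jvp` are hypothetical helper names (loosely modeled on the idea in the linked test, not copied from it), and `tf.square` stands in for a real model. Each row of `tangents` is one tangent direction for the same primal, so mapping over the rows of an identity matrix stacks the Jacobian columns.

```python
import tensorflow as tf


def jvp(f, primal, tangent):
    """Forward-mode JVP of f at `primal` along a single `tangent`."""
    with tf.autodiff.ForwardAccumulator(primal, tangent) as acc:
        out = f(primal)
    return acc.jvp(out)


def batched_jvp(f, primal, tangents):
    """Hypothetical helper: one JVP per row of `tangents`.

    tf.vectorized_map traces the mapped function rather than running it
    eagerly, so the ForwardAccumulator here runs inside a graph.
    """
    return tf.vectorized_map(lambda t: jvp(f, primal, t), tangents)


x = tf.constant([1.0, 2.0, 3.0])
# Each identity row is a basis tangent, so the result stacks the Jacobian
# columns of the elementwise square, i.e. diag(2 * x).
jac = batched_jvp(tf.square, x, tf.eye(3))
print(jac.numpy())
```

Because the mapped function is traced, the higher-order limitation mentioned above may still apply if `f` itself nests accumulators.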

You can also run forwardprop in a Python for loop if there aren’t many tangents. That’d happen eagerly.
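A sketch of the eager Python-loop variant, assuming only a handful of tangent directions. `loop_jvps` is a hypothetical helper and `sin(x) * x` is an arbitrary function chosen for illustration. Each iteration creates a fresh ForwardAccumulator, and nothing is wrapped in tf.function, so no tracing is involved.

```python
import tensorflow as tf


def loop_jvps(f, primal, tangents):
    """Compute one JVP per tangent with a plain Python loop (eager mode)."""
    results = []
    for tangent in tangents:
        # A new accumulator per tangent direction; executes op by op.
        with tf.autodiff.ForwardAccumulator(primal, tangent) as acc:
            out = f(primal)
        results.append(acc.jvp(out))
    return tf.stack(results)


x = tf.constant([1.0, 2.0, 3.0])
# Two illustrative tangent directions for the same primal x.
tangents = [tf.constant([1.0, 0.0, 0.0]), tf.constant([0.0, 1.0, 1.0])]
jvps = loop_jvps(lambda v: tf.sin(v) * v, x, tangents)
print(jvps.numpy())
```

The cost grows linearly with the number of tangents, which is why this is only attractive when there are few of them.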

I think the short answer is that tf.function produces function call operations, and the gradients/JVPs of call operations are more complicated than those of the equivalent eager operations, especially when nesting to take higher-order gradients. Graph optimizations usually re-simplify them, but apparently not in this case. There’s a plan to unify these and always do what we currently do in eager, but it won’t happen very quickly.
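One way to look at how much graph a nested forwardprop produces under tf.function is to inspect the traced GraphDef. The sketch below reduces the issue's nesting pattern to a single input; `second_derivative` and the cubic are illustrative assumptions, and the node and function counts depend on the TensorFlow version, so no particular numbers are claimed here.

```python
import tensorflow as tf


@tf.function(input_signature=[tf.TensorSpec([None], tf.float32)])
def second_derivative(x):
    # Same nesting as Issue_fwd.f, reduced to one input: the outer
    # accumulator differentiates the inner accumulator's JVP.
    with tf.autodiff.ForwardAccumulator(x, tf.ones_like(x)) as outer:
        with tf.autodiff.ForwardAccumulator(x, tf.ones_like(x)) as inner:
            y = x * x * x
        dy_dx = inner.jvp(y)  # 3 * x**2
    return outer.jvp(dy_dx)   # 6 * x


graph_def = second_derivative.get_concrete_function().graph.as_graph_def()
# Top-level nodes, plus the functions invoked by call ops, which are stored
# in the graph's function library.
print("top-level nodes:", len(graph_def.node))
print("library functions:", len(graph_def.library.function))
print(second_derivative(tf.constant([1.0, 2.0])).numpy())
```

Comparing these counts with a first-order version of the same function is a quick way to see whether nesting is what inflates the saved graph.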

I see why you get that error from experimental_compile with Issue_fwd.f but not with the decorated ff. It looks like it was introduced in https://github.com/tensorflow/tensorflow/commit/6a6261c0a0e803891af95f5e754180739df1897d

@yunxing do we need to register a gradient for xla_dynamic_update_slice? Is there a bug for it?