tensorflow: tf.saved_model.save very slow with second-order tf.autodiff.ForwardAccumulator
System information
- Have I written custom code (as opposed to using a stock example script provided in TensorFlow): yes
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Docker Hub container 'latest', digest 7bc36fe0ca1a051a808122e87f5438614b371263515df4794abef9a78440af8b
- GPU model and memory: no GPU; 4x Intel® Core™ i7-8650U CPU @ 1.90GHz, 32 GB RAM
Describe the current behavior
Saving a `tf.Module` that uses second-order `tf.autodiff.ForwardAccumulator` takes far too long: about an hour for the example below.
Describe the expected behavior
Saving the graph in the example should take a few seconds.
Standalone code to reproduce the issue
```python
import os
import time

import tensorflow as tf


class Issue_fwd(tf.Module):

    @tf.function(input_signature=[tf.TensorSpec([None, 1], tf.float64)] * 3 +
                 [tf.TensorSpec([None, 3], tf.float64)] +
                 [tf.TensorSpec([1, None], tf.float64)] * 4)
    def f(self, x1, x2, x3, c, v1, v2, v3, v4):
        # Outer accumulators record the second-order forward-mode tangents.
        with tf.autodiff.ForwardAccumulator(x1, tf.ones_like(x1)) as fwd_acc_x1_2, \
             tf.autodiff.ForwardAccumulator(x2, tf.ones_like(x2)) as fwd_acc_x2_2:
            # Inner accumulators record the first-order tangents.
            with tf.autodiff.ForwardAccumulator(x1, tf.ones_like(x1)) as fwd_acc_x1, \
                 tf.autodiff.ForwardAccumulator(x2, tf.ones_like(x2)) as fwd_acc_x2, \
                 tf.autodiff.ForwardAccumulator(x3, tf.ones_like(x3)) as fwd_acc_x3:
                # Gaussian RBF features between the query points and centers `c`.
                p = tf.concat([x1, x2, x3], axis=1)
                pe = tf.transpose(a=p[:, :, None], perm=[0, 2, 1])
                ce = tf.transpose(a=c[:, :, None], perm=[2, 0, 1])
                r = tf.reduce_sum(input_tensor=tf.square(ce - pe), axis=2)
                G = tf.exp(-r / 2)
                p = tf.reduce_sum(input_tensor=G * v1, axis=1, keepdims=True)
                b = tf.reduce_sum(input_tensor=G * v2, axis=1, keepdims=True)
                u = tf.reduce_sum(input_tensor=G * v3, axis=1, keepdims=True)
                w = tf.reduce_sum(input_tensor=G * v4, axis=1, keepdims=True)
            # First-order JVPs.
            dpdx = fwd_acc_x1.jvp(p, unconnected_gradients=tf.UnconnectedGradients.ZERO)
            dbdx = fwd_acc_x1.jvp(b, unconnected_gradients=tf.UnconnectedGradients.ZERO)
            dudx = fwd_acc_x1.jvp(u, unconnected_gradients=tf.UnconnectedGradients.ZERO)
            dwdx = fwd_acc_x1.jvp(w, unconnected_gradients=tf.UnconnectedGradients.ZERO)
            dpdz = fwd_acc_x2.jvp(p, unconnected_gradients=tf.UnconnectedGradients.ZERO)
            dbdz = fwd_acc_x2.jvp(b, unconnected_gradients=tf.UnconnectedGradients.ZERO)
            dudz = fwd_acc_x2.jvp(u, unconnected_gradients=tf.UnconnectedGradients.ZERO)
            dwdz = fwd_acc_x2.jvp(w, unconnected_gradients=tf.UnconnectedGradients.ZERO)
            dbdt = fwd_acc_x3.jvp(b, unconnected_gradients=tf.UnconnectedGradients.ZERO)
            dudt = fwd_acc_x3.jvp(u, unconnected_gradients=tf.UnconnectedGradients.ZERO)
            dwdt = fwd_acc_x3.jvp(w, unconnected_gradients=tf.UnconnectedGradients.ZERO)
        # Second-order JVPs (forward-over-forward).
        d2ud2x = fwd_acc_x1_2.jvp(dudx, unconnected_gradients=tf.UnconnectedGradients.ZERO)
        d2ud2z = fwd_acc_x2_2.jvp(dudz, unconnected_gradients=tf.UnconnectedGradients.ZERO)
        d2wd2x = fwd_acc_x1_2.jvp(dwdx, unconnected_gradients=tf.UnconnectedGradients.ZERO)
        d2wd2z = fwd_acc_x2_2.jvp(dwdz, unconnected_gradients=tf.UnconnectedGradients.ZERO)
        d2bd2x = fwd_acc_x1_2.jvp(dbdx, unconnected_gradients=tf.UnconnectedGradients.ZERO)
        d2bd2z = fwd_acc_x2_2.jvp(dbdz, unconnected_gradients=tf.UnconnectedGradients.ZERO)
        return (dudx, dudz, dudt, dwdx, dwdz, dwdt, dbdx, dbdz, dbdt,
                dpdx, dpdz, d2ud2x, d2ud2z, d2wd2x, d2wd2z, d2bd2x, d2bd2z)


f = Issue_fwd()
saving_path = 'save_path'
os.makedirs(saving_path, exist_ok=True)
start_time = time.perf_counter()  # time.clock() is deprecated and removed in Python 3.8
tf.saved_model.save(f, saving_path)
delta_time = time.perf_counter() - start_time
print('saving took {:f} seconds'.format(delta_time))
print('tf.version.GIT_VERSION={}'.format(tf.version.GIT_VERSION))
print('tf.version.VERSION={}'.format(tf.version.VERSION))
```
Other info / logs
saving took 3942.697280 seconds
tf.version.GIT_VERSION=v2.3.0-rc2-23-gb36436b087
tf.version.VERSION=2.3.0
About this issue
- Original URL
- State: closed
- Created 4 years ago
- Comments: 27 (15 by maintainers)
Ah, batching multiple tangents associated with one primal? A GSoC student started on built-in `tf.vectorized_map` integration, but unfortunately it's not complete (e.g. it won't work for higher-order derivatives yet because its `tf.function` integration has issues). You can certainly use it with `tf.vectorized_map` yourself: https://github.com/tensorflow/tensorflow/blob/3da9cc88d244c415a6ac1c0b81fb0983809829b0/tensorflow/python/eager/forwardprop_test.py#L93-L97

One issue is that `tf.vectorized_map` currently traces the function rather than executing it eagerly. You can also run forwardprop in a Python for loop if there aren't many tangents; that would happen eagerly.
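As an illustration of both suggestions, here is a minimal sketch; the helper names `jvp`, `batch_jvp_vectorized`, and `batch_jvp_loop` are hypothetical (not TensorFlow APIs), and it assumes `f` takes and returns a single tensor:

```python
import tensorflow as tf

def jvp(f, primals, tangents):
    # One forward-mode pass: the Jacobian of f at `primals` times `tangents`.
    with tf.autodiff.ForwardAccumulator(primals, tangents) as acc:
        out = f(primals)
    return acc.jvp(out)

def batch_jvp_vectorized(f, primals, tangents_batch):
    # Push one tangent per vectorized iteration through the accumulator.
    # Note: tf.vectorized_map traces `jvp` instead of running it eagerly.
    return tf.vectorized_map(lambda t: jvp(f, primals, t), tangents_batch)

def batch_jvp_loop(f, primals, tangents_batch):
    # Eager alternative: fine when there are only a few tangents.
    return tf.stack([jvp(f, primals, t) for t in tangents_batch])

x = tf.constant([1.0, 2.0, 3.0])
tangents = tf.eye(3)  # one basis tangent per row -> rows of the Jacobian
print(batch_jvp_vectorized(tf.sin, x, tangents))
print(batch_jvp_loop(tf.sin, x, tangents))
```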
I think the short answer is that `tf.function` produces function call operations, and the gradients/JVPs of call operations are more complicated than those of the equivalent eager operations, especially if you're nesting to take higher-order gradients. Usually graph optimizations re-simplify them, but apparently not in this case. There's a plan to unify these and always do what we do in eager right now, but it won't happen very quickly.

I see why you get that error from `experimental_compile` with `Issue_fwd.f` but not `ff` decorated. It looks like it was added in https://github.com/tensorflow/tensorflow/commit/6a6261c0a0e803891af95f5e754180739df1897d

@yunxing do we need to register a gradient for xla_dynamic_update_slice? Is there a bug for it?