tensorflow: experimental_run_v2 throws AttributeError with MultiWorkerMirroredStrategy
System information
- Have I written custom code (as opposed to using a stock example script provided in TensorFlow): yes
- TensorFlow installed from (source or binary): binary
- TensorFlow version: 2.1 and 2.2 (both affected)
Describe the current behavior
Calling strategy.experimental_run_v2 (or strategy.run in TF 2.2) on a MultiWorkerMirroredStrategy raises AttributeError: 'CollectiveAllReduceExtended' object has no attribute '_cfer_fn_cache' when the function passed in is a tf.function.
The error comes from the attribute access at https://github.com/tensorflow/tensorflow/blob/v2.2.0/tensorflow/python/distribute/mirrored_strategy.py#L743: CollectiveAllReduceExtended never calls the super().__init__ that creates that cache dictionary at https://github.com/tensorflow/tensorflow/blob/v2.2.0/tensorflow/python/distribute/mirrored_strategy.py#L472.
I noticed that the offending code was removed from current master by https://github.com/tensorflow/tensorflow/commit/b16d24a342c5de1384dcb9ee408a74f206d332b2, which presumably fixes this as well, but I want to make sure that change ends up in the next release or in a patch release. Since MultiWorkerMirroredStrategy is not mentioned in the commit, it might also be worth adding something like the reproduction below as a test case to prevent regressions.
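A possible stopgap until a fixed release is available is to create the missing cache attribute on the strategy's extended object before the first run call. This is only a sketch based on the cause described above; the exact dictionary type that MirroredExtended normally creates is an assumption here.

import weakref

import tensorflow as tf

strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()

# Assumption: MirroredExtended.__init__ initializes this cache as a weak-keyed
# dictionary; creating it manually avoids the AttributeError on access.
if not hasattr(strategy.extended, '_cfer_fn_cache'):
    strategy.extended._cfer_fn_cache = weakref.WeakKeyDictionary()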
Standalone code to reproduce the issue
import tensorflow as tf

strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()

with strategy.scope():
    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
    x_train, x_test = x_train / 255.0, x_test / 255.0
    train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(32)

    model = tf.keras.models.Sequential([
        tf.keras.layers.Flatten(input_shape=(28, 28)),
        tf.keras.layers.Dense(128, activation='relu'),
        tf.keras.layers.Dense(10),
    ])
    model.compile(
        optimizer=tf.keras.optimizers.SGD(),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=['accuracy'])

@tf.function
def train_step(model, data, target):
    with tf.GradientTape() as tape:
        predictions = model(data, training=True)
        loss = model.loss(target, predictions)
    gradients = tape.gradient(loss, model.trainable_variables)
    model.optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    return loss

def distributed_train_step(strategy, model, x, y):
    # Fails here: passing a tf.function triggers the _cfer_fn_cache lookup.
    strategy.experimental_run_v2(train_step, args=(model, x, y))

for x, y in train_dataset:
    distributed_train_step(strategy, model, x, y)
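For completeness, the TF 2.2 spelling of the distributed step fails in the same way. The sketch below is a hypothetical variant of the function above; reducing the per-replica loss is an illustrative addition and not part of the original report.

def distributed_train_step_v22(strategy, model, x, y):
    # strategy.run is the TF 2.2 replacement for experimental_run_v2 and
    # raises the same AttributeError on the affected versions.
    per_replica_loss = strategy.run(train_step, args=(model, x, y))
    # Average the per-replica losses into a single scalar for logging.
    return strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica_loss, axis=None)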
About this issue
- Original URL
- State: closed
- Created 4 years ago
- Comments: 15 (8 by maintainers)
We have recently added this notebook: https://www.tensorflow.org/probability/examples/Distributed_Inference_with_JAX
The approach should be similar with MirroredStrategy, though there may be some gymnastics needed to avoid a merge_call since ReplicaContext.all_reduce puts one in unconditionally.
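As a rough illustration of the replica-context reduction mentioned above (a minimal sketch, not taken from the notebook), an all-reduce issued from inside strategy.run looks like this; the merge_call that ReplicaContext.all_reduce inserts is what may need working around.

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

@tf.function
def replica_fn(x):
    ctx = tf.distribute.get_replica_context()
    # Sums x across replicas; internally this goes through a merge_call.
    return ctx.all_reduce(tf.distribute.ReduceOp.SUM, x)

result = strategy.run(replica_fn, args=(tf.constant(1.0),))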