tensorflow: add_update in cross-replica mode is broken (BatchNormalization layer impossible to use)

System information

  • Have I written custom code (as opposed to using a stock example script provided in TensorFlow): no
  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Archlinux
  • TensorFlow installed from (source or binary): binary
  • TensorFlow version (use command below): v1.12.1-3374-g9eb67b17bf 2.0.0-dev20190605
  • Python version: 3.6
  • CUDA/cuDNN version: 10
  • GPU model and memory: 1080 Ti

Describe the current behavior

I expect to be able to run a forward pass through a model containing a BatchNormalization layer in training mode while using tf.distribute.MirroredStrategy, but I can't, because it raises the following exception:

RuntimeError: add_update was called in a cross-replica context. This is not expected. If you require this feature, please file an issue.

Why is this not expected?
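
For context: code that runs directly under strategy.scope() executes in a cross-replica context, while code launched through the strategy runs once per replica. A minimal sketch of how to check which context you are in (assuming TF >= 2.2, where Strategy.run is available; the nightly in this report exposed it as experimental_run_v2):

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    # Directly under the scope we are in cross-replica mode.
    print(tf.distribute.in_cross_replica_context())  # True

    def replica_fn():
        # Inside strategy.run each replica has its own replica context.
        print(tf.distribute.in_cross_replica_context())  # False

    strategy.run(replica_fn)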

Describe the expected behavior

The forward pass should run without errors. The commit that introduced this behavior is: https://github.com/tensorflow/tensorflow/commit/316cd57883166e6a0b4c2d0eaacebddad7675b39#diff-8eb7e20502209f082d0cb15119a50413

Code to reproduce the issue

import tensorflow as tf

# Model with a BatchNormalization layer, whose moving statistics are
# registered via add_update during a forward pass in training mode.
model = tf.keras.Sequential(
    [
        tf.keras.layers.Dense(10),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.Dense(1),
    ]
)

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    # Calling the model directly inside the scope executes in a
    # cross-replica context, which triggers the RuntimeError above.
    out = model(tf.zeros((1, 10)), training=True)
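
One way to avoid the error until the fix landed (see the comments below) is to perform the forward pass per replica instead of in the cross-replica context. A minimal sketch, assuming TF >= 2.2 where Strategy.run is available (older versions named it experimental_run_v2):

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    model = tf.keras.Sequential(
        [
            tf.keras.layers.Dense(10),
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.Dense(1),
        ]
    )

def forward(x):
    # model(...) now runs in a replica context, so BatchNormalization's
    # add_update calls are no longer made cross-replica.
    return model(x, training=True)

out = strategy.run(forward, args=(tf.zeros((1, 10)),))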

About this issue

  • State: closed
  • Created 5 years ago
  • Reactions: 4
  • Comments: 38 (14 by maintainers)

Most upvoted comments

I confirm that this issue is fixed on tf-nightly v2.1.0-dev20200104.

Same issue for me!

It only occurs if I use tf.keras.callbacks.ModelCheckpoint, no problem otherwise…
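
For reference, a hypothetical sketch of the setup described in that comment; the filepath, optimizer, and data are placeholders, not taken from the original report:

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    model = tf.keras.Sequential(
        [
            tf.keras.layers.Dense(10),
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.Dense(1),
        ]
    )
    model.compile(optimizer="adam", loss="mse")

x = tf.zeros((8, 10))
y = tf.zeros((8, 1))

# Per the comment above, the error only surfaced with this callback attached.
checkpoint = tf.keras.callbacks.ModelCheckpoint("model.h5")  # placeholder path
model.fit(x, y, epochs=1, callbacks=[checkpoint])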