tensorflow: Error when using the batch renormalisation option under a distribution strategy

System information

  • Have I written custom code (as opposed to using a stock example script provided in TensorFlow): No
  • OS: Windows 10
  • TensorFlow installed from (source or binary): tensorflow-gpu binary on pip
  • TensorFlow version (use command below): 1.14
  • Python version: 3.6.8
  • CUDA/cuDNN version:
  • GPU model and memory:

Describe the current behavior
When I try to use renorm=True with tf.keras.layers.BatchNormalization under tf.distribute.MirroredStrategy, building the layer fails with:

ValueError: `handle` is not available outside the replica context or a `tf.distribute.Strategy.update()` call.

Describe the expected behavior
Batch renormalisation (renorm=True) can be used under tf.distribute.MirroredStrategy without raising this error.
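
A sketch of the kind of end-to-end workflow this implies is below (an assumed usage example, not taken from the report; the model architecture and dummy data are placeholders):

import numpy as np
import tensorflow as tf

strat = tf.distribute.MirroredStrategy()
with strat.scope():
  inp = tf.keras.Input((28, 28))
  x = tf.keras.layers.Flatten()(inp)
  # Expected to behave like plain batch normalization; currently this
  # call is where the ValueError above is raised.
  x = tf.keras.layers.BatchNormalization(renorm=True)(x)
  out = tf.keras.layers.Dense(10, activation='softmax')(x)
  model = tf.keras.Model(inp, out)
  model.compile(optimizer='adam', loss='sparse_categorical_crossentropy')

# Dummy data, only to exercise the training path.
model.fit(np.random.rand(64, 28, 28).astype('float32'),
          np.random.randint(10, size=64), epochs=1)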

Code to reproduce the issue

import tensorflow as tf

strat = tf.distribute.MirroredStrategy()
with strat.scope():
  inp = tf.keras.Input((28, 28))
  tf.keras.layers.BatchNormalization(renorm=True)(inp)
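
For comparison (not part of the original report): the same graph builds without error when renorm is left at its default of False, which suggests the failure is specific to the renorm moving-average update path.

import tensorflow as tf

strat = tf.distribute.MirroredStrategy()
with strat.scope():
  inp = tf.keras.Input((28, 28))
  # Default renorm=False: no error at build time.
  tf.keras.layers.BatchNormalization()(inp)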

Other info / logs

The following is the stack trace:

<ipython-input-3-d17cda041220> in <module>()
      2 with strat.scope():
      3   inp = tf.keras.Input((28, 28))
----> 4   tf.keras.layers.BatchNormalization(renorm=True)(inp)

16 frames
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/base_layer.py in __call__(self, inputs, *args, **kwargs)
    632                     outputs = base_layer_utils.mark_as_return(outputs, acd)
    633                 else:
--> 634                   outputs = call_fn(inputs, *args, **kwargs)
    635 
    636             except TypeError as e:

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/layers/normalization.py in call(self, inputs, training)
    734       if self.renorm:
    735         r, d, new_mean, new_variance = self._renorm_correction_and_moments(
--> 736             new_mean, new_variance, training, inputs_size)
    737         # When training, the normalized values (say, x) will be transformed as
    738         # x * gamma + beta without renorm, and (x * r + d) * gamma + beta

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/layers/normalization.py in _renorm_correction_and_moments(self, mean, variance, training, inputs_size)
    598     new_mean = _update_renorm_variable(self.renorm_mean,
    599                                        self.renorm_mean_weight, mean,
--> 600                                        inputs_size)
    601     new_stddev = _update_renorm_variable(self.renorm_stddev,
    602                                          self.renorm_stddev_weight, stddev,

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/layers/normalization.py in _update_renorm_variable(var, weight, value, inputs_size)
    593       def _fake_update():
    594         return array_ops.identity(var)
--> 595       return tf_utils.smart_cond(training, _do_update, _fake_update)
    596 
    597     # TODO(yuefengz): colocate the operations

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/utils/tf_utils.py in smart_cond(pred, true_fn, false_fn, name)
     56         pred, true_fn=true_fn, false_fn=false_fn, name=name)
     57   return smart_module.smart_cond(
---> 58       pred, true_fn=true_fn, false_fn=false_fn, name=name)
     59 
     60 

/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/smart_cond.py in smart_cond(pred, true_fn, false_fn, name)
     57   else:
     58     return control_flow_ops.cond(pred, true_fn=true_fn, false_fn=false_fn,
---> 59                                  name=name)
     60 
     61 

/usr/local/lib/python3.6/dist-packages/tensorflow/python/util/deprecation.py in new_func(*args, **kwargs)
    505                 'in a future version' if date is None else ('after %s' % date),
    506                 instructions)
--> 507       return func(*args, **kwargs)
    508 
    509     doc = _add_deprecated_arg_notice_to_docstring(

/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/control_flow_ops.py in cond(pred, true_fn, false_fn, strict, name, fn1, fn2)
   1975     try:
   1976       context_t.Enter()
-> 1977       orig_res_t, res_t = context_t.BuildCondBranch(true_fn)
   1978       if orig_res_t is None:
   1979         raise ValueError("true_fn must have a return value.")

/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/control_flow_ops.py in BuildCondBranch(self, fn)
   1812     """Add the subgraph defined by fn() to the graph."""
   1813     pre_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION)  # pylint: disable=protected-access
-> 1814     original_result = fn()
   1815     post_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION)  # pylint: disable=protected-access
   1816     if len(post_summaries) > len(pre_summaries):

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/layers/normalization.py in _do_update()
    582           weight_value = array_ops.constant(1., dtype=weight.dtype)
    583         new_var = self._assign_moving_average(var, value, self.renorm_momentum,
--> 584                                               inputs_size)
    585         new_weight = self._assign_moving_average(weight, weight_value,
    586                                                  self.renorm_momentum,

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/layers/normalization.py in _assign_moving_average(self, variable, value, momentum, inputs_size)
    447   def _assign_moving_average(self, variable, value, momentum, inputs_size):
    448     with K.name_scope('AssignMovingAvg') as scope:
--> 449       with ops.colocate_with(variable):
    450         decay = ops.convert_to_tensor(1.0 - momentum, name='decay')
    451         if decay.dtype != variable.dtype.base_dtype:

/usr/lib/python3.6/contextlib.py in __enter__(self)
     79     def __enter__(self):
     80         try:
---> 81             return next(self.gen)
     82         except StopIteration:
     83             raise RuntimeError("generator didn't yield") from None

/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/ops.py in _colocate_with_for_gradient(self, op, gradient_uid, ignore_existing)
   4398   def _colocate_with_for_gradient(self, op, gradient_uid,
   4399                                   ignore_existing=False):
-> 4400     with self.colocate_with(op, ignore_existing):
   4401       if gradient_uid is not None and self._control_flow_context is not None:
   4402         self._control_flow_context.EnterGradientColocation(op, gradient_uid)

/usr/lib/python3.6/contextlib.py in __enter__(self)
     79     def __enter__(self):
     80         try:
---> 81             return next(self.gen)
     82         except StopIteration:
     83             raise RuntimeError("generator didn't yield") from None

/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/ops.py in colocate_with(self, op, ignore_existing)
   4447       raise ValueError("Trying to reset colocation (op is None) but "
   4448                        "ignore_existing is not True")
-> 4449     op = _op_to_colocate_with(op)
   4450 
   4451     # By default, colocate_with resets the device function stack,

/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/ops.py in _op_to_colocate_with(v)
   6714   # happen soon, perhaps this hack to work around the circular
   6715   # import dependency is acceptable.
-> 6716   if hasattr(v, "handle") and hasattr(v.handle, "op") and isinstance(
   6717       v.handle.op, Operation):
   6718     return v.handle.op

/usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/values.py in handle(self)
    641       device = distribute_lib.get_update_device()
    642       if device is None:
--> 643         raise ValueError("`handle` is not available outside the replica context"
    644                          " or a `tf.distribute.Strategy.update()` call.")
    645     return self.get(device=device).handle
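
Reading the trace bottom-up, the renorm moving-statistics update (_update_renorm_variable -> _assign_moving_average) enters ops.colocate_with(variable), which reads the variable's resource handle; for a variable created under MirroredStrategy, that handle is only available inside a replica context or a tf.distribute.Strategy.update() call, hence the ValueError. Below is a minimal sketch of that underlying check, independent of Keras (an assumption based on the values.py frame above, not something from the report):

import tensorflow as tf

strat = tf.distribute.MirroredStrategy()
with strat.scope():
  # Variables created inside the scope are mirrored across replicas.
  v = tf.Variable(1.0)

# Reading the per-replica resource handle in cross-replica context is
# expected (per the values.py frame above) to trip the same check.
v.handle  # ValueError: `handle` is not available outside the replica context ...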

About this issue

  • Original URL
  • State: closed
  • Created 5 years ago
  • Reactions: 2
  • Comments: 15 (5 by maintainers)

Most upvoted comments

Still exists in TF 2.1.

Was able to reproduce this issue with TensorFlow 1.14. The gist is in this link.

Thanks.

I am running into this issue as well. Any update on the fix?