tensorflow: Error when using batch renormalisation option under a strategy
System information
- Have I written custom code (as opposed to using a stock example script provided in TensorFlow): No
- OS: Windows 10
- TensorFlow installed from (source or binary): tensorflow-gpu binary on pip
- TensorFlow version (use command below): 1.14
- Python version: 3.6.8
- CUDA/cuDNN version:
- GPU model and memory:
Describe the current behavior
When I try to use `renorm=True` with `tf.keras.layers.BatchNormalization` under `tf.distribute.MirroredStrategy`, I get the following error: ValueError: `handle` is not available outside the replica context or a `tf.distribute.Strategy.update()` call.
Describe the expected behavior
I can use batch renormalisation under mirrored strategy.
Code to reproduce the issue
```python
import tensorflow as tf

strat = tf.distribute.MirroredStrategy()
with strat.scope():
    inp = tf.keras.Input((28, 28))
    tf.keras.layers.BatchNormalization(renorm=True)(inp)
```
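For contrast, the same graph builds fine when `renorm` is left at its default, which narrows the failure to the renorm update path; a minimal sketch under the same setup:

```python
import tensorflow as tf

strat = tf.distribute.MirroredStrategy()
with strat.scope():
    inp = tf.keras.Input((28, 28))
    # Default batch normalization builds without error under the strategy;
    # only the renorm path trips over the variable `handle` access.
    tf.keras.layers.BatchNormalization()(inp)
```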
Other info / logs
The following is the stack trace:

```
<ipython-input-3-d17cda041220> in <module>()
2 with strat.scope():
3 inp = tf.keras.Input((28, 28))
----> 4 tf.keras.layers.BatchNormalization(renorm=True)(inp)
16 frames
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/base_layer.py in __call__(self, inputs, *args, **kwargs)
632 outputs = base_layer_utils.mark_as_return(outputs, acd)
633 else:
--> 634 outputs = call_fn(inputs, *args, **kwargs)
635
636 except TypeError as e:
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/layers/normalization.py in call(self, inputs, training)
734 if self.renorm:
735 r, d, new_mean, new_variance = self._renorm_correction_and_moments(
--> 736 new_mean, new_variance, training, inputs_size)
737 # When training, the normalized values (say, x) will be transformed as
738 # x * gamma + beta without renorm, and (x * r + d) * gamma + beta
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/layers/normalization.py in _renorm_correction_and_moments(self, mean, variance, training, inputs_size)
598 new_mean = _update_renorm_variable(self.renorm_mean,
599 self.renorm_mean_weight, mean,
--> 600 inputs_size)
601 new_stddev = _update_renorm_variable(self.renorm_stddev,
602 self.renorm_stddev_weight, stddev,
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/layers/normalization.py in _update_renorm_variable(var, weight, value, inputs_size)
593 def _fake_update():
594 return array_ops.identity(var)
--> 595 return tf_utils.smart_cond(training, _do_update, _fake_update)
596
597 # TODO(yuefengz): colocate the operations
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/utils/tf_utils.py in smart_cond(pred, true_fn, false_fn, name)
56 pred, true_fn=true_fn, false_fn=false_fn, name=name)
57 return smart_module.smart_cond(
---> 58 pred, true_fn=true_fn, false_fn=false_fn, name=name)
59
60
/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/smart_cond.py in smart_cond(pred, true_fn, false_fn, name)
57 else:
58 return control_flow_ops.cond(pred, true_fn=true_fn, false_fn=false_fn,
---> 59 name=name)
60
61
/usr/local/lib/python3.6/dist-packages/tensorflow/python/util/deprecation.py in new_func(*args, **kwargs)
505 'in a future version' if date is None else ('after %s' % date),
506 instructions)
--> 507 return func(*args, **kwargs)
508
509 doc = _add_deprecated_arg_notice_to_docstring(
/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/control_flow_ops.py in cond(pred, true_fn, false_fn, strict, name, fn1, fn2)
1975 try:
1976 context_t.Enter()
-> 1977 orig_res_t, res_t = context_t.BuildCondBranch(true_fn)
1978 if orig_res_t is None:
1979 raise ValueError("true_fn must have a return value.")
/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/control_flow_ops.py in BuildCondBranch(self, fn)
1812 """Add the subgraph defined by fn() to the graph."""
1813 pre_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION) # pylint: disable=protected-access
-> 1814 original_result = fn()
1815 post_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION) # pylint: disable=protected-access
1816 if len(post_summaries) > len(pre_summaries):
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/layers/normalization.py in _do_update()
582 weight_value = array_ops.constant(1., dtype=weight.dtype)
583 new_var = self._assign_moving_average(var, value, self.renorm_momentum,
--> 584 inputs_size)
585 new_weight = self._assign_moving_average(weight, weight_value,
586 self.renorm_momentum,
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/layers/normalization.py in _assign_moving_average(self, variable, value, momentum, inputs_size)
447 def _assign_moving_average(self, variable, value, momentum, inputs_size):
448 with K.name_scope('AssignMovingAvg') as scope:
--> 449 with ops.colocate_with(variable):
450 decay = ops.convert_to_tensor(1.0 - momentum, name='decay')
451 if decay.dtype != variable.dtype.base_dtype:
/usr/lib/python3.6/contextlib.py in __enter__(self)
79 def __enter__(self):
80 try:
---> 81 return next(self.gen)
82 except StopIteration:
83 raise RuntimeError("generator didn't yield") from None
/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/ops.py in _colocate_with_for_gradient(self, op, gradient_uid, ignore_existing)
4398 def _colocate_with_for_gradient(self, op, gradient_uid,
4399 ignore_existing=False):
-> 4400 with self.colocate_with(op, ignore_existing):
4401 if gradient_uid is not None and self._control_flow_context is not None:
4402 self._control_flow_context.EnterGradientColocation(op, gradient_uid)
/usr/lib/python3.6/contextlib.py in __enter__(self)
79 def __enter__(self):
80 try:
---> 81 return next(self.gen)
82 except StopIteration:
83 raise RuntimeError("generator didn't yield") from None
/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/ops.py in colocate_with(self, op, ignore_existing)
4447 raise ValueError("Trying to reset colocation (op is None) but "
4448 "ignore_existing is not True")
-> 4449 op = _op_to_colocate_with(op)
4450
4451 # By default, colocate_with resets the device function stack,
/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/ops.py in _op_to_colocate_with(v)
6714 # happen soon, perhaps this hack to work around the circular
6715 # import dependency is acceptable.
-> 6716 if hasattr(v, "handle") and hasattr(v.handle, "op") and isinstance(
6717 v.handle.op, Operation):
6718 return v.handle.op
/usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/values.py in handle(self)
641 device = distribute_lib.get_update_device()
642 if device is None:
--> 643 raise ValueError("`handle` is not available outside the replica context"
644 " or a `tf.distribute.Strategy.update()` call.")
645 return self.get(device=device).handle

ValueError: `handle` is not available outside the replica context or a `tf.distribute.Strategy.update()` call.
```
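The bottom of the trace is the key part: the renorm moving-average update calls `ops.colocate_with(variable)` from cross-replica context, which reads the variable's `handle`, and a `MirroredVariable` only exposes `handle` inside a replica context or a `tf.distribute.Strategy.update()` call. As a rough illustration of the pattern the error message refers to (a sketch assuming TF 2.x-style eager execution; the hypothetical `assign_fn` is only for demonstration, not a fix for the layer itself):

```python
import tensorflow as tf

strat = tf.distribute.MirroredStrategy()
with strat.scope():
    var = tf.Variable(1.0)  # becomes a MirroredVariable under the strategy

# Reading the handle directly from cross-replica context raises the same
# ValueError the layer hits:
#   var.handle

def assign_fn(v, value):
    # Inside `update`, the component variable on each device is in scope,
    # so its handle (and assign ops) are available.
    return v.assign(value)

strat.extended.update(var, assign_fn, args=(2.0,))
```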
About this issue
- Original URL
- State: closed
- Created 5 years ago
- Reactions: 2
- Comments: 15 (5 by maintainers)
Still existing in TF 2.1.
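For what it's worth, the error message suggests the updates would be legal from inside a replica context; a rough, untested sketch of that direction against the TF 2.1 API (the `train_step` function and shapes are hypothetical):

```python
import tensorflow as tf

strat = tf.distribute.MirroredStrategy()
with strat.scope():
    bn = tf.keras.layers.BatchNormalization(renorm=True)
    # build() creates the variables without running call(), so it avoids
    # the cross-replica renorm update that raises the ValueError.
    bn.build((None, 28, 28))

@tf.function
def train_step(x):
    def step_fn(inputs):
        # Inside the replica context the per-replica `handle` is available,
        # so the renorm moving-average assignments can be constructed.
        return bn(inputs, training=True)
    return strat.experimental_run_v2(step_fn, args=(x,))

train_step(tf.ones((2, 28, 28)))
```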
I could reproduce this issue with TensorFlow version 1.14; the Gist is in this link. Thanks.
I am running into this issue as well. Any update on the fix?