tensorflow: ValueError raised when saving a model created in a MirroredStrategy scope

System information

  • Have I written custom code (as opposed to using a stock example script provided in TensorFlow): Yes
  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04):
  • Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if the issue happens on mobile device:
  • TensorFlow installed from (source or binary): binary
  • TensorFlow version (use command below): tf-nightly
  • Python version: 3.6 (note: the traceback paths point to a python3.7 environment)
  • Bazel version (if compiling from source):
  • GCC/Compiler version (if compiling from source):
  • CUDA/cuDNN version:
  • GPU model and memory:

Describe the current behavior

The model is created inside a MirroredStrategy scope. When I save the model with model.save(save_path) after training, it raises ValueError: SyncOnReadVariable does not support `assign_add` in cross-replica context when aggregation is set to `tf.VariableAggregation.SUM`. The error is triggered here. The related traceback is:

  File "ncf_keras_main.py", line 85, in call
    self.add_metric(hr_sum, name="hr_sum", aggregation="mean")
  File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/keras/engine/base_layer.py", line 1678, in add_metric
    metric_obj(value)
  File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/keras/metrics.py", line 231, in __call__
    replica_local_fn, *args, **kwargs)
  File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/keras/distribute/distributed_training_utils.py", line 1133, in call_replica_local_fn
    return fn(*args, **kwargs)
  File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/keras/metrics.py", line 211, in replica_local_fn
    update_op = self.update_state(*args, **kwargs)  # pylint: disable=not-callable
  File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/keras/utils/metrics_utils.py", line 90, in decorated
    update_op = update_state_fn(*args, **kwargs)
  File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/keras/metrics.py", line 176, in update_state_fn
    return ag_update_state(*args, **kwargs)
  File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/autograph/impl/api.py", line 302, in wrapper
    return func(*args, **kwargs)
  File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/keras/metrics.py", line 373, in update_state
    update_total_op = self.total.assign_add(value_sum)
  File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/distribute/values.py", line 918, in assign_add
    "SyncOnReadVariable does not support `assign_add` in "
ValueError: SyncOnReadVariable does not support `assign_add` in cross-replica context when aggregation is set to `tf.VariableAggregation.SUM`.

I have also attached the complete traceback for reference.
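The error message is easier to read with the semantics of a sync-on-read variable in mind. The sketch below is a toy, pure-Python model (no TensorFlow; `ToySyncOnReadVariable`, `Aggregation` and the method names are hypothetical) of how I understand `SyncOnReadVariable` to treat `assign_add`: each replica owns a local copy, and a read from cross-replica context aggregates those copies. With SUM aggregation, adding an increment to every copy would change the aggregated value by `num_replicas` times that increment, so there is no unambiguous cross-replica update and the toy refuses it, mirroring the ValueError above. With MEAN, adding the increment to every copy shifts the mean by exactly that increment.

```python
from enum import Enum


class Aggregation(Enum):
    SUM = "sum"
    MEAN = "mean"


class ToySyncOnReadVariable:
    """Hypothetical stand-in for a sync-on-read variable: one local copy per
    replica, aggregated only when read from cross-replica context."""

    def __init__(self, num_replicas, aggregation):
        self.values = [0.0] * num_replicas
        self.aggregation = aggregation

    def read(self):
        # Cross-replica read: aggregate all local copies.
        total = sum(self.values)
        if self.aggregation is Aggregation.SUM:
            return total
        return total / len(self.values)

    def assign_add_in_replica(self, replica_id, delta):
        # Replica context: only this replica's local copy changes.
        self.values[replica_id] += delta

    def assign_add_cross_replica(self, delta):
        if self.aggregation is Aggregation.SUM:
            # Adding delta to every copy would add num_replicas * delta to the
            # aggregated sum, so the update is ambiguous and is rejected.
            raise ValueError(
                "SyncOnReadVariable does not support `assign_add` in "
                "cross-replica context when aggregation is SUM.")
        # MEAN: adding delta to every copy shifts the mean by exactly delta.
        self.values = [v + delta for v in self.values]


summed = ToySyncOnReadVariable(num_replicas=2, aggregation=Aggregation.SUM)
summed.assign_add_in_replica(0, 1.0)
summed.assign_add_in_replica(1, 2.0)
print(summed.read())  # 3.0

averaged = ToySyncOnReadVariable(num_replicas=2, aggregation=Aggregation.MEAN)
averaged.assign_add_cross_replica(5.0)
print(averaged.read())  # 5.0

try:
    summed.assign_add_cross_replica(1.0)
except ValueError as err:
    print("ValueError:", err)
```

In the traceback, the metric update runs while `model.save` traces the model (`default_save_signature` calling `model(inputs, training=False)`), which appears to be the cross-replica case in this toy.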

Describe the expected behavior

Standalone code to reproduce the issue

Other info / logs

Traceback (most recent call last):
  File "ncf_keras_main.py", line 568, in <module>
    app.run(main)
  File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/absl/app.py", line 299, in run
    _run_main(main, args)
  File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/absl/app.py", line 250, in _run_main
    sys.exit(main(argv))
  File "ncf_keras_main.py", line 563, in main
    logging.info("Result is %s", run_ncf(FLAGS))
  File "ncf_keras_main.py", line 351, in run_ncf
    keras_model.save("save_model")
  File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py", line 1950, in save
    signatures, options)
  File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/keras/saving/save.py", line 134, in save_model
    signatures, options)
  File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/keras/saving/saved_model/save.py", line 78, in save
    save_lib.save(model, filepath, signatures, options)
  File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/saved_model/save.py", line 953, in save
    obj, export_dir, signatures, options, meta_graph_def)
  File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/saved_model/save.py", line 1015, in _build_meta_graph
    checkpoint_graph_view)
  File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/saved_model/signature_serialization.py", line 75, in find_function_to_export
    functions = saveable_view.list_functions(saveable_view.root)
  File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/saved_model/save.py", line 144, in list_functions
    self._serialization_cache)
  File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py", line 2543, in _list_functions_for_serialization
    Model, self)._list_functions_for_serialization(serialization_cache)
  File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/keras/engine/base_layer.py", line 3014, in _list_functions_for_serialization
    .list_functions_for_serialization(serialization_cache))
  File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/keras/saving/saved_model/base_serialization.py", line 87, in list_functions_for_serialization
    fns = self.functions_to_serialize(serialization_cache)
  File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/keras/saving/saved_model/layer_serialization.py", line 77, in functions_to_serialize
    serialization_cache).functions_to_serialize)
  File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/keras/saving/saved_model/layer_serialization.py", line 92, in _get_serialized_attributes
    serialization_cache)
  File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/keras/saving/saved_model/model_serialization.py", line 51, in _get_serialized_attributes_internal
    default_signature = save_impl.default_save_signature(self.obj)
  File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/keras/saving/saved_model/save_impl.py", line 205, in default_save_signature
    fn.get_concrete_function()
  File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py", line 1168, in get_concrete_function
    concrete = self._get_concrete_function_garbage_collected(*args, **kwargs)
  File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py", line 1074, in _get_concrete_function_garbage_collected
    self._initialize(args, kwargs, add_initializers_to=initializers)
  File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py", line 697, in _initialize
    *args, **kwds))
  File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 2842, in _get_concrete_function_internal_garbage_collected
    graph_function, _, _ = self._maybe_define_function(args, kwargs)
  File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 3200, in _maybe_define_function
    graph_function = self._create_graph_function(args, kwargs)
  File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 3062, in _create_graph_function
    capture_by_value=self._capture_by_value),
  File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py", line 979, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
  File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py", line 600, in wrapped_fn
    return weak_wrapped_fn().__wrapped__(*args, **kwds)
  File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/keras/saving/saving_utils.py", line 132, in _wrapped_model
    outputs = model(inputs, training=False)
  File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/keras/engine/base_layer.py", line 961, in __call__
    outputs = call_fn(inputs, *args, **kwargs)
  File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/keras/engine/functional.py", line 385, in call
    inputs, training=training, mask=mask)
  File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/keras/engine/functional.py", line 507, in _run_internal_graph
    outputs = node.layer(*args, **kwargs)
  File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/keras/engine/base_layer.py", line 961, in __call__
    outputs = call_fn(inputs, *args, **kwargs)
  File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/autograph/impl/api.py", line 302, in wrapper
    return func(*args, **kwargs)
  File "ncf_keras_main.py", line 85, in call
    self.add_metric(hr_sum, name="hr_sum", aggregation="mean")
  File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/keras/engine/base_layer.py", line 1678, in add_metric
    metric_obj(value)
  File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/keras/metrics.py", line 231, in __call__
    replica_local_fn, *args, **kwargs)
  File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/keras/distribute/distributed_training_utils.py", line 1133, in call_replica_local_fn
    return fn(*args, **kwargs)
  File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/keras/metrics.py", line 211, in replica_local_fn
    update_op = self.update_state(*args, **kwargs)  # pylint: disable=not-callable
  File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/keras/utils/metrics_utils.py", line 90, in decorated
    update_op = update_state_fn(*args, **kwargs)
  File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/keras/metrics.py", line 176, in update_state_fn
    return ag_update_state(*args, **kwargs)
  File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/autograph/impl/api.py", line 302, in wrapper
    return func(*args, **kwargs)
  File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/keras/metrics.py", line 373, in update_state
    update_total_op = self.total.assign_add(value_sum)
  File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/distribute/values.py", line 918, in assign_add
    "SyncOnReadVariable does not support `assign_add` in "
ValueError: SyncOnReadVariable does not support `assign_add` in cross-replica context when aggregation is set to `tf.VariableAggregation.SUM`.

About this issue

  • State: closed
  • Created 4 years ago
  • Comments: 16 (2 by maintainers)

Most upvoted comments

This also seems to be a problem with tf.keras.metrics.Mean().

I'm using TF 2.5.0 and get the same error with MirroredStrategy.

I'm facing the same issue with TensorFlow 2.8.0, using metric = tfk.metrics.Mean with MirroredStrategy, when calling metric.update_state(value). This is basic functionality in the Keras ecosystem; it is strange that this bug has persisted across so many versions.
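Why `Mean` is affected: as far as I know, a tf.keras 2.x `Mean` keeps two state variables, a running `total` and a `count` (the traceback shows `self.total.assign_add(value_sum)`), and `Metric.add_weight` creates them as sync-on-read variables with SUM aggregation by default. The toy below (pure Python; `ToyReplicatedMean` is hypothetical and does not call TensorFlow) shows why SUM is the natural choice: summing per-replica totals and counts at read time gives the correct global mean even when replicas see different numbers of values, whereas averaging per-replica means does not. The same choice is what makes a cross-replica `update_state` ambiguous.

```python
class ToyReplicatedMean:
    """Hypothetical model of a Keras-style Mean under a mirrored strategy:
    each replica keeps its own running total and count."""

    def __init__(self, num_replicas):
        self.totals = [0.0] * num_replicas
        self.counts = [0.0] * num_replicas

    def update_state_in_replica(self, replica_id, values):
        # Replica-local update: touches only this replica's copies.
        self.totals[replica_id] += sum(values)
        self.counts[replica_id] += len(values)

    def result(self):
        # Cross-replica read: SUM-aggregate totals and counts, then divide.
        return sum(self.totals) / sum(self.counts)


metric = ToyReplicatedMean(num_replicas=2)
metric.update_state_in_replica(0, [1.0, 2.0, 3.0])  # replica 0 sees 3 values
metric.update_state_in_replica(1, [10.0])           # replica 1 sees 1 value
print(metric.result())  # 4.0, the mean of all four values

# Averaging the per-replica means instead would give a different answer.
per_replica_means = [t / c for t, c in zip(metric.totals, metric.counts)]
print(sum(per_replica_means) / len(per_replica_means))  # 6.0
```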

I'm using TensorFlow 2.9.1 and found a workaround: do not define the metrics within the strategy scope.

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    model = ...
    loss_fn = ...
    optimizer = ...

# Out of scope, this works.
metrics = tf.keras.metrics.Mean(name='total_loss')

This works when you are using a custom training loop. I didn't test it with Model.fit() because I didn't have enough time.
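The workaround fits the usual custom-training-loop shape: each replica computes its own loss, the per-replica values are reduced in cross-replica code (in TensorFlow, typically `strategy.run` followed by `strategy.reduce`), and only the reduced scalar is fed to a metric created outside `strategy.scope()`, which is then presumably an ordinary single-copy variable rather than a sync-on-read one. The toy below mimics that flow in plain Python; `PlainMean`, `train_step`, and the batch values are hypothetical, it does not call TensorFlow, and I have not checked that the real workaround avoids the error in every TF version.

```python
class PlainMean:
    """Hypothetical single-copy running mean, standing in for a metric created
    outside strategy.scope(): there are no per-replica copies to reconcile."""

    def __init__(self):
        self.total = 0.0
        self.count = 0

    def update_state(self, value):
        self.total += value
        self.count += 1

    def result(self):
        return self.total / self.count if self.count else 0.0


def train_step(per_replica_batches):
    # Hypothetical per-replica "loss": the mean of each replica's batch.
    per_replica_losses = [sum(b) / len(b) for b in per_replica_batches]
    # Reduce across replicas first (similar in spirit to strategy.reduce with
    # a MEAN reduce op), so the metric only ever receives one scalar.
    return sum(per_replica_losses) / len(per_replica_losses)


total_loss = PlainMean()
steps = [
    [[1.0, 3.0], [2.0, 2.0]],  # step 1: two replicas, two examples each
    [[4.0], [6.0]],            # step 2: two replicas, one example each
]
for per_replica_batches in steps:
    # Outer (cross-replica) loop: update the plain metric with the reduced loss.
    total_loss.update_state(train_step(per_replica_batches))
print(total_loss.result())  # 3.5
```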

The model should not be compiled within the MirroredStrategy scope.