TensorFlow: ValueError raised when saving a model created inside a MirroredStrategy scope
System information
- Have I written custom code (as opposed to using a stock example script provided in TensorFlow): Yes
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04):
- Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if the issue happens on mobile device:
- TensorFlow installed from (source or binary): binary
- TensorFlow version (use command below): tf-nightly
- Python version: 3.6
- Bazel version (if compiling from source):
- GCC/Compiler version (if compiling from source):
- CUDA/cuDNN version:
- GPU model and memory:
Describe the current behavior
The model is created inside a MirroredStrategy scope. When I save the model using model.save(save_path) after training, it raises ValueError: SyncOnReadVariable does not support `assign_add` in cross-replica context when aggregation is set to `tf.VariableAggregation.SUM`. The error is raised in tensorflow/python/distribute/values.py. The relevant part of the traceback is:
File "ncf_keras_main.py", line 85, in call
self.add_metric(hr_sum, name="hr_sum", aggregation="mean")
File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/keras/engine/base_layer.py", line 1678, in add_metric
metric_obj(value)
File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/keras/metrics.py", line 231, in __call__
replica_local_fn, *args, **kwargs)
File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/keras/distribute/distributed_training_utils.py", line 1133, in call_replica_local_fn
return fn(*args, **kwargs)
File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/keras/metrics.py", line 211, in replica_local_fn
update_op = self.update_state(*args, **kwargs) # pylint: disable=not-callable
File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/keras/utils/metrics_utils.py", line 90, in decorated
update_op = update_state_fn(*args, **kwargs)
File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/keras/metrics.py", line 176, in update_state_fn
return ag_update_state(*args, **kwargs)
File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/autograph/impl/api.py", line 302, in wrapper
return func(*args, **kwargs)
File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/keras/metrics.py", line 373, in update_state
update_total_op = self.total.assign_add(value_sum)
File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/distribute/values.py", line 918, in assign_add
"SyncOnReadVariable does not support `assign_add` in "
ValueError: SyncOnReadVariable does not support `assign_add` in cross-replica context when aggregation is set to `tf.VariableAggregation.SUM`.
I have also included the complete traceback below for reference.
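To illustrate why this particular update is rejected, here is a toy pure-Python model of a sync-on-read variable. It is not TensorFlow code: the class `ToySyncOnReadVariable` and its methods are hypothetical, and the rules are a simplified reading of the error message above (the MEAN branch is an assumption about how the non-SUM case behaves). Each replica keeps its own local copy, and a cross-replica read aggregates the copies. A cross-replica `assign_add` has no single correct meaning under SUM: adding the value to every copy would add it once per replica to the aggregated result, while adding it to only one copy is an arbitrary choice. Under MEAN, adding the value to every copy raises the mean by exactly that value.

```python
class ToySyncOnReadVariable:
    """Toy model of a sync-on-read variable (hypothetical, not TensorFlow code)."""

    def __init__(self, num_replicas, aggregation):
        if aggregation not in ("SUM", "MEAN"):
            raise ValueError("aggregation must be 'SUM' or 'MEAN'")
        self.aggregation = aggregation
        # One local copy per replica, all starting at zero.
        self.local_values = [0.0] * num_replicas

    def replica_assign_add(self, replica_id, value):
        # Replica context: each replica only touches its own copy.
        self.local_values[replica_id] += value

    def cross_replica_assign_add(self, value):
        # Cross-replica context: there is no single "local" copy to update.
        if self.aggregation == "SUM":
            # Adding to every copy would add value * num_replicas to the sum,
            # so the toy refuses, mirroring the ValueError in this report.
            raise ValueError(
                "cross-replica assign_add is ambiguous with SUM aggregation")
        # MEAN (assumed behavior): adding to every copy raises the mean by `value`.
        self.local_values = [v + value for v in self.local_values]

    def read(self):
        # A cross-replica read aggregates the local copies.
        total = sum(self.local_values)
        if self.aggregation == "SUM":
            return total
        return total / len(self.local_values)


# Per-replica updates are fine for both aggregations.
total = ToySyncOnReadVariable(num_replicas=2, aggregation="SUM")
total.replica_assign_add(0, 3.0)
total.replica_assign_add(1, 4.0)
print(total.read())  # 7.0

try:
    total.cross_replica_assign_add(1.0)
except ValueError as err:
    print("ValueError:", err)

mean = ToySyncOnReadVariable(num_replicas=2, aggregation="MEAN")
mean.replica_assign_add(0, 2.0)
mean.replica_assign_add(1, 4.0)
mean.cross_replica_assign_add(1.0)
print(mean.read())  # 4.0
```

In this report the variable being updated is the metric's `total` (see `self.total.assign_add(value_sum)` in the traceback). The update happens because `model.save` calls the model (`model(inputs, training=False)` in saving_utils.py) to trace a default signature, and that call runs the layer's `add_metric` outside any per-replica step.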
Describe the expected behavior
Standalone code to reproduce the issue
Other info / logs
Traceback (most recent call last):
File "ncf_keras_main.py", line 568, in <module>
app.run(main)
File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/absl/app.py", line 299, in run
_run_main(main, args)
File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/absl/app.py", line 250, in _run_main
sys.exit(main(argv))
File "ncf_keras_main.py", line 563, in main
logging.info("Result is %s", run_ncf(FLAGS))
File "ncf_keras_main.py", line 351, in run_ncf
keras_model.save("save_model")
File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py", line 1950, in save
signatures, options)
File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/keras/saving/save.py", line 134, in save_model
signatures, options)
File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/keras/saving/saved_model/save.py", line 78, in save
save_lib.save(model, filepath, signatures, options)
File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/saved_model/save.py", line 953, in save
obj, export_dir, signatures, options, meta_graph_def)
File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/saved_model/save.py", line 1015, in _build_meta_graph
checkpoint_graph_view)
File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/saved_model/signature_serialization.py", line 75, in find_function_to_export
functions = saveable_view.list_functions(saveable_view.root)
File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/saved_model/save.py", line 144, in list_functions
self._serialization_cache)
File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py", line 2543, in _list_functions_for_serialization
Model, self)._list_functions_for_serialization(serialization_cache)
File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/keras/engine/base_layer.py", line 3014, in _list_functions_for_serialization
.list_functions_for_serialization(serialization_cache))
File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/keras/saving/saved_model/base_serialization.py", line 87, in list_functions_for_serialization
fns = self.functions_to_serialize(serialization_cache)
File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/keras/saving/saved_model/layer_serialization.py", line 77, in functions_to_serialize
serialization_cache).functions_to_serialize)
File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/keras/saving/saved_model/layer_serialization.py", line 92, in _get_serialized_attributes
serialization_cache)
File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/keras/saving/saved_model/model_serialization.py", line 51, in _get_serialized_attributes_internal
default_signature = save_impl.default_save_signature(self.obj)
File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/keras/saving/saved_model/save_impl.py", line 205, in default_save_signature
fn.get_concrete_function()
File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py", line 1168, in get_concrete_function
concrete = self._get_concrete_function_garbage_collected(*args, **kwargs)
File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py", line 1074, in _get_concrete_function_garbage_collected
self._initialize(args, kwargs, add_initializers_to=initializers)
File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py", line 697, in _initialize
*args, **kwds))
File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 2842, in _get_concrete_function_internal_garbage_collected
graph_function, _, _ = self._maybe_define_function(args, kwargs)
File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 3200, in _maybe_define_function
graph_function = self._create_graph_function(args, kwargs)
File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 3062, in _create_graph_function
capture_by_value=self._capture_by_value),
File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py", line 979, in func_graph_from_py_func
func_outputs = python_func(*func_args, **func_kwargs)
File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py", line 600, in wrapped_fn
return weak_wrapped_fn().__wrapped__(*args, **kwds)
File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/keras/saving/saving_utils.py", line 132, in _wrapped_model
outputs = model(inputs, training=False)
File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/keras/engine/base_layer.py", line 961, in __call__
outputs = call_fn(inputs, *args, **kwargs)
File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/keras/engine/functional.py", line 385, in call
inputs, training=training, mask=mask)
File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/keras/engine/functional.py", line 507, in _run_internal_graph
outputs = node.layer(*args, **kwargs)
File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/keras/engine/base_layer.py", line 961, in __call__
outputs = call_fn(inputs, *args, **kwargs)
File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/autograph/impl/api.py", line 302, in wrapper
return func(*args, **kwargs)
File "ncf_keras_main.py", line 85, in call
self.add_metric(hr_sum, name="hr_sum", aggregation="mean")
File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/keras/engine/base_layer.py", line 1678, in add_metric
metric_obj(value)
File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/keras/metrics.py", line 231, in __call__
replica_local_fn, *args, **kwargs)
File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/keras/distribute/distributed_training_utils.py", line 1133, in call_replica_local_fn
return fn(*args, **kwargs)
File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/keras/metrics.py", line 211, in replica_local_fn
update_op = self.update_state(*args, **kwargs) # pylint: disable=not-callable
File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/keras/utils/metrics_utils.py", line 90, in decorated
update_op = update_state_fn(*args, **kwargs)
File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/keras/metrics.py", line 176, in update_state_fn
return ag_update_state(*args, **kwargs)
File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/autograph/impl/api.py", line 302, in wrapper
return func(*args, **kwargs)
File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/keras/metrics.py", line 373, in update_state
update_total_op = self.total.assign_add(value_sum)
File "/home/ml/users/jdong25/.conda/envs/tf/lib/python3.7/site-packages/tensorflow/python/distribute/values.py", line 918, in assign_add
"SyncOnReadVariable does not support `assign_add` in "
ValueError: SyncOnReadVariable does not support `assign_add` in cross-replica context when aggregation is set to `tf.VariableAggregation.SUM`.
About this issue
- State: closed
- Created 4 years ago
- Comments: 16 (2 by maintainers)
This also seems to be a problem with tf.keras.metrics.Mean().
Using TF 2.5.0, I get the same error with MirroredStrategy.
Facing the same issue with TensorFlow 2.8.0 using metric = tfk.metrics.Mean with MirroredStrategy when calling metric.update_state(value). This is basic functionality in the Keras ecosystem. Strange that this bug has persisted for so many versions.
I'm using TensorFlow 2.9.1, and I found a workaround: do not define the metrics within the scope. This works when you are using a custom training loop. I didn't test it with Model#fit() because I didn't have enough time.
The model should not be compiled within the MirroredStrategy scope.
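Two commenters hit the same error with tf.keras.metrics.Mean. Assuming a mean metric keeps a running total and count (the traceback above confirms the `total` part via `self.total.assign_add`), the toy below shows why those accumulators are combined with SUM rather than MEAN: averaging per-replica means gives the wrong answer when replicas see different numbers of examples. The functions `summed_mean` and `mean_of_means` are hypothetical pure-Python helpers, not Keras code.

```python
def summed_mean(per_replica_batches):
    # Each replica accumulates its own running total and count.
    totals = [sum(batch) for batch in per_replica_batches]
    counts = [len(batch) for batch in per_replica_batches]
    # SUM-aggregate both accumulators, then divide once.
    return sum(totals) / sum(counts)


def mean_of_means(per_replica_batches):
    # What MEAN-aggregating per-replica averages would report instead.
    means = [sum(batch) / len(batch) for batch in per_replica_batches]
    return sum(means) / len(means)


# Replica 0 sees three values, replica 1 sees one.
batches = [[1.0, 2.0, 3.0], [10.0]]
print(summed_mean(batches))    # 4.0, the true mean of all four values
print(mean_of_means(batches))  # 6.0, skewed toward the smaller replica
```

Because the accumulators need SUM aggregation to give the correct result, they fall under the cross-replica `assign_add` restriction from the original report, which is consistent with the Mean-based reports above.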