tensorflow: Model can't be checkpointed with Keras + MultiWorkerMirroredStrategy


System information

Have I written custom code (as opposed to using a stock example script provided in TensorFlow):
OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Darwin-18.0.0-x86_64-i386-64bit
Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if the issue happens on mobile device:
TensorFlow installed from (source or binary): binary
TensorFlow version (use command below): 2.0.0-beta1
Python version: 3.6.8
Bazel version (if compiling from source):
GCC/Compiler version (if compiling from source):
CUDA/cuDNN version:
GPU model and memory:

Describe the current behavior
This is a follow-up to https://github.com/tensorflow/tensorflow/issues/31070.

I tried the following two solutions:
1. I applied 6345ad5 to my local TensorFlow installation.
2. I installed the latest nightly dev build.

Both gave me the following error. It seems that although the previous commit changed the data type to int64, somewhere else still expects int32.

2019-08-01 22:41:51.971726: W tensorflow/core/framework/op_kernel.cc:1546] OP_REQUIRES failed at collective_ops.cc:354 : Internal: RecvBufResponse returned 8 bytes where to_tensor expected 4
Traceback (most recent call last):
  File "example_tf2.py", line 124, in <module>
    steps_per_epoch = parallel_steps)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/keras/engine/training.py", line 643, in fit
    use_multiprocessing=use_multiprocessing)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/keras/engine/training_distributed.py", line 776, in wrapper
    mode=dc.CoordinatorMode.INDEPENDENT_WORKER)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/distribute/distribute_coordinator.py", line 853, in run_distribute_coordinator
    task_id, session_config, rpc_layer)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/distribute/distribute_coordinator.py", line 360, in _run_single_worker
    return worker_fn(strategy)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/keras/engine/training_distributed.py", line 771, in _worker_fn
    return fn(instance, model, **kwargs)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/keras/engine/training_distributed.py", line 681, in fit
    steps_name='steps_per_epoch')
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/keras/engine/training_arrays.py", line 294, in model_iteration
    batch_outs = f(actual_inputs)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/keras/distribute/distributed_training_utils.py", line 813, in execution_function
    return [out.numpy() for out in distributed_function(input_fn)]
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/eager/def_function.py", line 416, in __call__
    self._initialize(args, kwds, add_initializers_to=initializer_map)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/eager/def_function.py", line 359, in _initialize
    *args, **kwds))
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/eager/function.py", line 1360, in _get_concrete_function_internal_garbage_collected
    graph_function, _, _ = self._maybe_define_function(args, kwargs)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/eager/function.py", line 1648, in _maybe_define_function
    graph_function = self._create_graph_function(args, kwargs)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/eager/function.py", line 1541, in _create_graph_function
    capture_by_value=self._capture_by_value),
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/func_graph.py", line 716, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/eager/def_function.py", line 309, in wrapped_fn
    return weak_wrapped_fn().__wrapped__(*args, **kwds)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/func_graph.py", line 706, in wrapper
    raise e.ag_error_metadata.to_exception(type(e))
tensorflow.python.autograph.impl.api.StagingError: in converted code:

    /usr/local/lib/python2.7/dist-packages/tensorflow/python/keras/distribute/distributed_training_utils.py:804 distributed_function  *
        outputs = strategy.experimental_run_v2(
    /usr/local/lib/python2.7/dist-packages/tensorflow/python/distribute/distribute_lib.py:708 experimental_run_v2
        return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs)
    /usr/local/lib/python2.7/dist-packages/tensorflow/python/distribute/distribute_lib.py:1710 call_for_each_replica
        return self._call_for_each_replica(fn, args, kwargs)
    /usr/local/lib/python2.7/dist-packages/tensorflow/python/distribute/mirrored_strategy.py:708 _call_for_each_replica
        fn, args, kwargs)
    /usr/local/lib/python2.7/dist-packages/tensorflow/python/distribute/mirrored_strategy.py:195 _call_for_each_replica
        coord.join(threads)
    /usr/local/lib/python2.7/dist-packages/tensorflow/python/training/coordinator.py:389 join
        six.reraise(*self._exc_info_to_raise)
    /usr/local/lib/python2.7/dist-packages/tensorflow/python/training/coordinator.py:297 stop_on_exception
        yield
    /usr/local/lib/python2.7/dist-packages/tensorflow/python/distribute/mirrored_strategy.py:926 run
        self.main_result = self.main_fn(*self.main_args, **self.main_kwargs)
    /usr/local/lib/python2.7/dist-packages/tensorflow/python/keras/engine/training.py:908 train_on_batch
        output_loss_metrics=self._output_loss_metrics)
    /usr/local/lib/python2.7/dist-packages/tensorflow/python/keras/engine/training_eager.py:307 train_on_batch
        output_loss_metrics=output_loss_metrics))
    /usr/local/lib/python2.7/dist-packages/tensorflow/python/keras/engine/training_eager.py:260 _process_single_batch
        model.trainable_weights))
    /usr/local/lib/python2.7/dist-packages/tensorflow/python/keras/optimizer_v2/optimizer_v2.py:434 apply_gradients
        self._create_hypers()
    /usr/local/lib/python2.7/dist-packages/tensorflow/python/keras/optimizer_v2/optimizer_v2.py:608 _create_hypers
        aggregation=tf_variables.VariableAggregation.ONLY_FIRST_REPLICA)
    /usr/local/lib/python2.7/dist-packages/tensorflow/python/keras/optimizer_v2/optimizer_v2.py:770 add_weight
        aggregation=aggregation)
    /usr/local/lib/python2.7/dist-packages/tensorflow/python/training/tracking/base.py:713 _add_variable_with_custom_getter
        **kwargs_for_getter)
    /usr/local/lib/python2.7/dist-packages/tensorflow/python/keras/engine/base_layer_utils.py:154 make_variable
        shape=variable_shape if variable_shape else None)
    /usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/variables.py:260 __call__
        return cls._variable_v1_call(*args, **kwargs)
    /usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/variables.py:221 _variable_v1_call
        shape=shape)
    /usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/variables.py:60 getter
        return captured_getter(captured_previous, **kwargs)
    /usr/local/lib/python2.7/dist-packages/tensorflow/python/distribute/shared_variable_creator.py:69 create_new_variable
        v = next_creator(*args, **kwargs)
    /usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/variables.py:60 getter
        return captured_getter(captured_previous, **kwargs)
    /usr/local/lib/python2.7/dist-packages/tensorflow/python/distribute/distribute_lib.py:1250 creator_with_resource_vars
        return self._create_variable(*args, **kwargs)
    /usr/local/lib/python2.7/dist-packages/tensorflow/python/distribute/collective_all_reduce_strategy.py:368 _create_variable
        _real_mirrored_creator, *args, **kwargs)
    /usr/local/lib/python2.7/dist-packages/tensorflow/python/distribute/mirrored_strategy.py:251 _create_mirrored_variable
        value_list = real_mirrored_creator(devices, *args, **kwargs)
    /usr/local/lib/python2.7/dist-packages/tensorflow/python/distribute/collective_all_reduce_strategy.py:355 _real_mirrored_creator
        v = next_creator(*args, **kwargs)
    /usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/variables.py:60 getter
        return captured_getter(captured_previous, **kwargs)
    /usr/local/lib/python2.7/dist-packages/tensorflow/python/eager/def_function.py:347 variable_capturing_scope
        lifted_initializer_graph=lifted_initializer_graph, **kwds)
    /usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/variables.py:264 __call__
        return super(VariableMetaclass, cls).__call__(*args, **kwargs)
    /usr/local/lib/python2.7/dist-packages/tensorflow/python/eager/def_function.py:139 __init__
        initial_value() if init_from_fn else initial_value,
    /usr/local/lib/python2.7/dist-packages/tensorflow/python/distribute/collective_all_reduce_strategy.py:330 _overridden_initial_value_fn
        group_key, collective_instance_key)
    /usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/collective_ops.py:161 broadcast_recv
        instance_key=instance_key)
    /usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/gen_collective_ops.py:66 collective_bcast_recv
        _six.raise_from(_core._status_to_exception(e.code, message), None)
    /root/.local/lib/python2.7/site-packages/six.py:737 raise_from
        raise value

    InternalError: RecvBufResponse returned 8 bytes where to_tensor expected 4 [Op:CollectiveBcastRecv]


Describe the expected behavior
The Keras model can be checkpointed under multi-worker training.

Code to reproduce the issue

from __future__ import absolute_import, division, print_function, unicode_literals
import datetime
import json
import os
import tensorflow_datasets as tfds
import tensorflow as tf
import subprocess
import shlex
import sys

tfds.disable_progress_bar()

BUFFER_SIZE = 60000
BATCH_SIZE = 64

NUM_WORKERS = 2
GLOBAL_BATCH_SIZE = NUM_WORKERS * BATCH_SIZE

if __name__ == "__main__":
  worker_addrs = ['localhost:9999', 'localhost:9998']
  os.environ['TF_CONFIG'] = json.dumps({
      'cluster': {
          'worker': worker_addrs,
      },
      'task': {'type': 'worker', 'index': int(sys.argv[1])}
  })

  strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()

  def scale(image, label):
    image = tf.cast(image, tf.float32)
    image /= 255
    return image, label

  def build_and_compile_cnn_model():
    model = tf.keras.Sequential([
        tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
        tf.keras.layers.MaxPooling2D(),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dense(10, activation='softmax')
    ])
    model.compile(
        loss=tf.keras.losses.sparse_categorical_crossentropy,
        optimizer=tf.keras.optimizers.SGD(learning_rate=0.001),
        metrics=['accuracy'])
    return model

  datasets, info = tfds.load(name='mnist',
                             with_info=True,
                             as_supervised=True)

  train_datasets_unbatched = datasets['train'].map(scale).shuffle(BUFFER_SIZE)

  train_datasets = train_datasets_unbatched.batch(GLOBAL_BATCH_SIZE)

  with strategy.scope():
    multi_worker_model = build_and_compile_cnn_model()

   
  checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
        filepath='/tmp/chk.hdf5',
        monitor='val_loss',
        save_best_only=True,
        load_weights_on_restart=True)

  multi_worker_model.fit(x=train_datasets, epochs=100, callbacks=[checkpoint_callback])
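The script takes the worker index as its first argument, so reproducing the error requires starting two processes, one per entry in worker_addrs. A hypothetical launcher, assuming the script above is saved as example_tf2.py (the filename that appears in the traceback):

import subprocess
import sys

# Start both workers locally; each process reads its task index from argv[1],
# which the repro script uses to fill in TF_CONFIG.
procs = [subprocess.Popen([sys.executable, 'example_tf2.py', str(idx)])
         for idx in range(2)]
for p in procs:
    p.wait()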

Other info / logs

About this issue

  • Original URL
  • State: closed
  • Created 5 years ago
  • Comments: 21 (8 by maintainers)

Most upvoted comments

For those who are still stuck on this error, I was able to get rid of it by making sure every worker has the same set of callbacks. In my case, I had previously added the ModelCheckpoint callback only on the first worker, e.g.

if is_master:
  callbacks.add(tf.keras.callbacks.ModelCheckpoint(...))

This caused the RecvBufResponse error in TF 2.2.0 and caused hanging in TF 2.4.0. Removing the is_master check solved the problem for me. IMO this requirement should be better documented, or at least the error message should be more descriptive.
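For reference, a minimal sketch of the pattern that avoided the error for me (assuming callbacks is a plain Python list; the make_callbacks helper and the checkpoint paths are only illustrative): every worker builds the identical callback list, and only the output directory differs, so non-chief workers write to scratch space.

import tempfile
import tensorflow as tf

def make_callbacks(is_master):
    # Every worker registers the same ModelCheckpoint so all workers execute
    # the same collective ops; non-chief workers just write to a temp dir.
    ckpt_dir = '/tmp/chk' if is_master else tempfile.mkdtemp()  # hypothetical paths
    return [
        tf.keras.callbacks.ModelCheckpoint(
            filepath=ckpt_dir + '/weights.{epoch:02d}.hdf5',
            save_best_only=False),
    ]

callbacks = make_callbacks(is_master=True)  # same structure on every worker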

from datetime import datetime
from packaging import version
import os
import tensorflow as tf
import numpy as np
import json
# Create a TensorBoard callback
logs = "logs/" + datetime.now().strftime("%Y%m%d-%H%M%S")
tboard_callback = tf.keras.callbacks.TensorBoard(log_dir = logs,
                                                 histogram_freq = 1,
                                                 profile_batch = '2048')
os.environ['TF_CONFIG'] = json.dumps({
    'cluster': {
        'worker': ["node67:12345", "node68:23456"]
    },
    'task': {'type': 'worker', 'index': 1}
})
def mnist_dataset(batch_size):
    (x_train, y_train), _ = tf.keras.datasets.mnist.load_data()
    # The `x` arrays are in uint8 and have values in the range [0, 255].
    # We need to convert them to float32 with values in the range [0, 1]
    x_train = x_train / np.float32(255)
    y_train = y_train.astype(np.int64)
    train_dataset = tf.data.Dataset.from_tensor_slices(
      (x_train, y_train)).shuffle(60000).repeat().batch(batch_size).prefetch(100)
    return train_dataset

def build_and_compile_cnn_model():
    model = tf.keras.Sequential([
      tf.keras.Input(shape=(28, 28)),
      tf.keras.layers.Reshape(target_shape=(28, 28, 1)),
      tf.keras.layers.Conv2D(256, 2, activation='relu'),
      tf.keras.layers.Conv2D(128, 2, activation='relu'),
      tf.keras.layers.Conv2D(32, 1, activation='relu'),  
      tf.keras.layers.Conv2D(32, 2, activation='relu'),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(2048, activation='relu'),        
      tf.keras.layers.Dense(1024, activation='relu'),        
      tf.keras.layers.Dense(128, activation='relu'),
      tf.keras.layers.Dense(10)
    ])
    model.compile(
      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
      optimizer=tf.keras.optimizers.SGD(learning_rate=0.001),
      metrics=['accuracy'])
    return model
strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
num_workers = 2
per_worker_batch_size = 2048
# Here the batch size scales up by the number of workers since
# `tf.data.Dataset.batch` expects the global batch size: with a per-worker
# batch size of 2048 and 2 workers, the global batch size is 4096.
global_batch_size = per_worker_batch_size * num_workers
multi_worker_dataset = mnist_dataset(global_batch_size)

with strategy.scope():
  multi_worker_model = build_and_compile_cnn_model()
multi_worker_model.fit(multi_worker_dataset, epochs=103, steps_per_epoch=70, callbacks=[tboard_callback])
---------------------------------------------------------------------------
InternalError                             Traceback (most recent call last)
<ipython-input-7-402c6ad248bd> in <module>
     14 # number of steps per epoch. Note that the numbers here are for demonstration
     15 # purposes only and may not sufficiently produce a model with good quality.
---> 16 multi_worker_model.fit(multi_worker_dataset, epochs=103, steps_per_epoch=70,callbacks = [tboard_callback])

~/miniconda2/envs/tf2/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py in _method_wrapper(self, *args, **kwargs)
     73         lambda _: method(self, *args, **kwargs),
     74         self.distribute_strategy,
---> 75         mode=dc.CoordinatorMode.INDEPENDENT_WORKER)
     76 
     77   return tf_decorator.make_decorator(

~/miniconda2/envs/tf2/lib/python3.6/site-packages/tensorflow/python/distribute/distribute_coordinator.py in run_distribute_coordinator(worker_fn, strategy, eval_fn, eval_strategy, mode, cluster_spec, task_type, task_id, session_config, rpc_layer)
    851         # All jobs run `worker_fn` if between-graph.
    852         return _run_single_worker(worker_fn, strategy, cluster_spec, task_type,
--> 853                                   task_id, session_config, rpc_layer)
    854       else:
    855         # Only one node runs `worker_fn` if in-graph.

~/miniconda2/envs/tf2/lib/python3.6/site-packages/tensorflow/python/distribute/distribute_coordinator.py in _run_single_worker(worker_fn, strategy, cluster_spec, task_type, task_id, session_config, rpc_layer, worker_barrier, coord)
    358         return worker_fn(strategy)
    359     else:
--> 360       return worker_fn(strategy)
    361 
    362 

~/miniconda2/envs/tf2/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py in <lambda>(_)
     71 
     72     return dc.run_distribute_coordinator(
---> 73         lambda _: method(self, *args, **kwargs),
     74         self.distribute_strategy,
     75         mode=dc.CoordinatorMode.INDEPENDENT_WORKER)

~/miniconda2/envs/tf2/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_batch_size, validation_freq, max_queue_size, workers, use_multiprocessing)
    828       self.stop_training = False
    829       train_function = self.make_train_function()
--> 830       callbacks.on_train_begin()
    831       # Handle fault-tolerance for multi-worker.
    832       # TODO(omalleyt): Fix the ordering issues that mean this has to

~/miniconda2/envs/tf2/lib/python3.6/site-packages/tensorflow/python/keras/callbacks.py in on_train_begin(self, logs)
    445     logs = self._process_logs(logs)
    446     for callback in self.callbacks:
--> 447       callback.on_train_begin(logs)
    448 
    449   def on_train_end(self, logs=None):

~/miniconda2/envs/tf2/lib/python3.6/site-packages/tensorflow/python/keras/callbacks.py in on_train_begin(self, logs)
   1948 
   1949   def on_train_begin(self, logs=None):
-> 1950     self._init_batch_steps()
   1951     if self._start_batch == 1:
   1952       self._enable_trace()

~/miniconda2/envs/tf2/lib/python3.6/site-packages/tensorflow/python/keras/callbacks.py in _init_batch_steps(self)
   1893       self._total_batches_seen = {
   1894           self._train_run_name: variables.Variable(0, dtype='int64'),
-> 1895           self._validation_run_name: variables.Variable(0, dtype='int64')
   1896       }
   1897     else:

~/miniconda2/envs/tf2/lib/python3.6/site-packages/tensorflow/python/ops/variables.py in __call__(cls, *args, **kwargs)
    259       return cls._variable_v1_call(*args, **kwargs)
    260     elif cls is Variable:
--> 261       return cls._variable_v2_call(*args, **kwargs)
    262     else:
    263       return super(VariableMetaclass, cls).__call__(*args, **kwargs)

~/miniconda2/envs/tf2/lib/python3.6/site-packages/tensorflow/python/ops/variables.py in _variable_v2_call(cls, initial_value, trainable, validate_shape, caching_device, name, variable_def, dtype, import_scope, constraint, synchronization, aggregation, shape)
    253         synchronization=synchronization,
    254         aggregation=aggregation,
--> 255         shape=shape)
    256 
    257   def __call__(cls, *args, **kwargs):

~/miniconda2/envs/tf2/lib/python3.6/site-packages/tensorflow/python/ops/variables.py in getter(**kwargs)
     64 
     65   def getter(**kwargs):
---> 66     return captured_getter(captured_previous, **kwargs)
     67 
     68   return getter

~/miniconda2/envs/tf2/lib/python3.6/site-packages/tensorflow/python/distribute/distribute_lib.py in creator_with_resource_vars(next_creator, **kwargs)
   1765         kwargs["initial_value"] = kwargs["initial_value"].wrapped_value
   1766 
-> 1767       return self._create_variable(next_creator, **kwargs)
   1768 
   1769     def distributed_getter(getter, *args, **kwargs):

~/miniconda2/envs/tf2/lib/python3.6/site-packages/tensorflow/python/distribute/mirrored_strategy.py in _create_variable(self, next_creator, **kwargs)
    608                                            _real_mirrored_creator,
    609                                            values.MirroredVariable,
--> 610                                            values.SyncOnReadVariable, **kwargs)
    611 
    612   def _validate_colocate_with_variable(self, colocate_with_variable):

~/miniconda2/envs/tf2/lib/python3.6/site-packages/tensorflow/python/distribute/values.py in create_mirrored_variable(strategy, real_mirrored_creator, mirrored_cls, sync_on_read_cls, **kwargs)
    692   # here.
    693   with tape.stop_recording():
--> 694     value_list = real_mirrored_creator(**kwargs)
    695     var_cls = sync_on_read_cls if is_sync_on_read else mirrored_cls
    696     result = var_cls(strategy, value_list, aggregation)

~/miniconda2/envs/tf2/lib/python3.6/site-packages/tensorflow/python/distribute/mirrored_strategy.py in _real_mirrored_creator(**kwargs)
    600             # variable creation.
    601             with tape.stop_recording():
--> 602               v = next_creator(**kwargs)
    603           assert not isinstance(v, values.DistributedVariable)
    604           value_list.append(v)

~/miniconda2/envs/tf2/lib/python3.6/site-packages/tensorflow/python/ops/variables.py in <lambda>(**kws)
    234                         shape=None):
    235     """Call on Variable class. Useful to force the signature."""
--> 236     previous_getter = lambda **kws: default_variable_creator_v2(None, **kws)
    237     for _, getter in ops.get_default_graph()._variable_creator_stack:  # pylint: disable=protected-access
    238       previous_getter = _make_getter(getter, previous_getter)

~/miniconda2/envs/tf2/lib/python3.6/site-packages/tensorflow/python/ops/variable_scope.py in default_variable_creator_v2(next_creator, **kwargs)
   2645       synchronization=synchronization,
   2646       aggregation=aggregation,
-> 2647       shape=shape)
   2648 
   2649 

~/miniconda2/envs/tf2/lib/python3.6/site-packages/tensorflow/python/ops/variables.py in __call__(cls, *args, **kwargs)
    261       return cls._variable_v2_call(*args, **kwargs)
    262     else:
--> 263       return super(VariableMetaclass, cls).__call__(*args, **kwargs)
    264 
    265 

~/miniconda2/envs/tf2/lib/python3.6/site-packages/tensorflow/python/ops/resource_variable_ops.py in __init__(self, initial_value, trainable, collections, validate_shape, caching_device, name, dtype, variable_def, import_scope, constraint, distribute_strategy, synchronization, aggregation, shape)
   1432           aggregation=aggregation,
   1433           shape=shape,
-> 1434           distribute_strategy=distribute_strategy)
   1435 
   1436   def _init_from_args(self,

~/miniconda2/envs/tf2/lib/python3.6/site-packages/tensorflow/python/ops/resource_variable_ops.py in _init_from_args(self, initial_value, trainable, collections, caching_device, name, dtype, constraint, synchronization, aggregation, distribute_strategy, shape)
   1565           with ops.name_scope("Initializer"), device_context_manager(None):
   1566             initial_value = ops.convert_to_tensor(
-> 1567                 initial_value() if init_from_fn else initial_value,
   1568                 name="initial_value", dtype=dtype)
   1569           if shape is not None:

~/miniconda2/envs/tf2/lib/python3.6/site-packages/tensorflow/python/distribute/collective_all_reduce_strategy.py in initial_value_fn()
    384                                                    initial_value.dtype,
    385                                                    group_size, group_key,
--> 386                                                    collective_instance_key)
    387           return initial_value
    388 

~/miniconda2/envs/tf2/lib/python3.6/site-packages/tensorflow/python/ops/collective_ops.py in broadcast_recv(shape, dtype, group_size, group_key, instance_key, communication_hint)
    174       group_key=group_key,
    175       instance_key=instance_key,
--> 176       communication_hint=communication_hint.lower())

~/miniconda2/envs/tf2/lib/python3.6/site-packages/tensorflow/python/ops/gen_collective_ops.py in collective_bcast_recv(T, group_size, group_key, instance_key, shape, communication_hint, name)
     55         pass  # Add nodes to the TensorFlow graph.
     56     except _core._NotOkStatusException as e:
---> 57       _ops.raise_from_not_ok_status(e, name)
     58   # Add nodes to the TensorFlow graph.
     59   T = _execute.make_type(T, "T")

~/miniconda2/envs/tf2/lib/python3.6/site-packages/tensorflow/python/framework/ops.py in raise_from_not_ok_status(e, name)
   6651   message = e.message + (" name: " + name if name is not None else "")
   6652   # pylint: disable=protected-access
-> 6653   six.raise_from(core._status_to_exception(e.code, message), None)
   6654   # pylint: enable=protected-access
   6655 

~/.local/lib/python3.6/site-packages/six.py in raise_from(value, from_value)

InternalError: RecvBufResponse returned 4 bytes where to_tensor expected 8 [Op:CollectiveBcastRecv]

I had the TensorBoard callback configured only on one worker, thinking I could show results from just that worker. After replicating the TensorBoard callback code on the other workers as well, the error goes away.
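A minimal sketch of what that replication looks like (the per-worker log directory naming is my own convention, not something TensorFlow requires): every worker creates the same TensorBoard callback, writing events to a worker-specific subdirectory.

import json
import os
import tensorflow as tf

# Derive a per-worker log directory from TF_CONFIG so every worker runs an
# identical callback while keeping its event files separate.
task = json.loads(os.environ['TF_CONFIG'])['task']
log_dir = os.path.join('logs', 'worker-{}'.format(task['index']))
tboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)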

Thanks for providing this example @nimaaghli. I’m running this script in 2.2 but it seems to be training without problems, though I haven’t made it to the last epoch yet. Does this fail right away for you, or after a certain number of epochs? What TF version are you using?