tensorflow: TPUStrategy.run fails on non-primary thread with "No OpKernel was registered", "TPUReplicatedInput"

System information

  • Have I written custom code (as opposed to using a stock example script provided in TensorFlow): Yes, included here
  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): (a) Google Colab, (b) Debian 10 (Buster), Google Cloud VM TF2.3.1 image
  • TensorFlow installed from (source or binary): (a) Google Colab, (b) Google Cloud VM TF2.3.1 image
  • TensorFlow version (use command below): (a) Google Colab, (b) v2.3.0-54-gfcc4b966f1 2.3.1
  • Python version: (a) Google Colab, (b) 3.7.3 (Google Cloud VM TF2.3.1 image)
  • GPU model and memory: (a) Google Colab TPU, (b) Google Cloud TPU, us-central1-f, v2-8, 2.3.1

Describe the current behavior

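Calling TPUStrategy.run from a secondary thread fails with InvalidArgumentError: No OpKernel was registered to support Op 'TPUReplicatedInput' (full log below).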

Describe the expected behavior

Calling TPUStrategy.run from a secondary thread succeeds, exactly as it does on the primary thread.

This blocks applications such as AlphaZero-style MCTS self-play driven from C++ that run inference through TPUStrategy.experimental_distribute_dataset, where either the primary C++ thread must perform other housekeeping, or multiple prediction threads are used to keep the inference pipeline full. The minimal Python-only example below produces the same error. When experimental_distribute_dataset is used the error message differs slightly, but the final "Op:__inference_tpu_function_177" is the same in both cases.

Standalone code to reproduce the issue

This snippet runs in a Google Colab notebook with a TPU runtime:

import tensorflow as tf
import threading

resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
strategy = tf.distribute.TPUStrategy(resolver)

@tf.function
def double(x):
  return x * 2.0

def test():
  input = tf.range(5, dtype=tf.float32)
  strategy.run(double, args=(input,))

test()
print("TPUStrategy.run works on primary thread")

# The same call on a secondary thread fails with InvalidArgumentError (log below)
thread = threading.Thread(target=test)
thread.start()
thread.join()

Other info / logs

Exception in thread Thread-5:
Traceback (most recent call last):
  File "/usr/lib/python3.6/threading.py", line 916, in _bootstrap_inner
    self.run()
  File "/usr/lib/python3.6/threading.py", line 864, in run
    self._target(*self._args, **self._kwargs)
  File "<ipython-input-2-c35bbe6196e2>", line 15, in test
    strategy.run(double, args=(input,))
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/tpu_strategy.py", line 279, in run
    return self.extended.tpu_run(fn, args, kwargs, options)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/tpu_strategy.py", line 1095, in tpu_run
    return func(args, kwargs)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/def_function.py", line 780, in __call__
    result = self._call(*args, **kwds)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/def_function.py", line 814, in _call
    results = self._stateful_fn(*args, **kwds)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py", line 2829, in __call__
    return graph_function._filtered_call(args, kwargs)  # pylint: disable=protected-access
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py", line 1848, in _filtered_call
    cancellation_manager=cancellation_manager)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py", line 1924, in _call_flat
    ctx, args, cancellation_manager=cancellation_manager))
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py", line 550, in call
    ctx=ctx)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/execute.py", line 60, in quick_execute
    inputs, attrs, num_outputs)
tensorflow.python.framework.errors_impl.InvalidArgumentError: No OpKernel was registered to support Op 'TPUReplicatedInput' used by {{node input0}} with these attrs: [T=DT_INT32, index=0, is_mirrored_variable=false, N=8, is_packed=false]
Registered devices: [CPU, TPU, TPU_SYSTEM, XLA_CPU]
Registered kernels:
  <no registered kernels>

	 [[input0]] [Op:__inference_tpu_function_177]

About this issue

  • State: open
  • Created 4 years ago
  • Comments: 18 (11 by maintainers)

Most upvoted comments

Hi @mohantym, in your TF 2.9 and 2.11 gists you're actually pasting the workaround, not the original repro.

Please see my comment on April 23rd, above:

I think there’s been a misunderstanding here.

A change to the test function has been offered as a workaround, but it is not a fix for the bug. I'm not aware of any TF documentation saying that, on a new thread, you need to run the magic line with tf.device("/job:worker"):
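For reference, here is a minimal sketch of how that workaround can be packaged, assuming the caller knows which device scope the worker thread needs. tf.device scopes are tracked per thread, so the scope has to be entered on the new thread itself. On a Cloud TPU the workaround quoted above uses "/job:worker" and the function would call strategy.run(...); this sketch uses "/cpu:0" and a plain tf.function so it runs without a TPU. run_in_device_scope is a hypothetical helper, not a TensorFlow API.

```python
import threading

import tensorflow as tf


def run_in_device_scope(fn, device_name, results, *args):
    """Hypothetical helper: enter tf.device on the *current* (worker) thread.

    tf.device scopes live on a per-thread stack, so a scope opened on the
    parent thread is not visible here; it has to be re-entered explicitly.
    """
    with tf.device(device_name):
        results.append(fn(*args))


@tf.function
def double(x):
    return x * 2.0


results = []
# Assumption: on a TPU VM device_name would be "/job:worker" (the workaround
# quoted above); "/cpu:0" keeps this sketch runnable without a TPU.
thread = threading.Thread(
    target=run_in_device_scope,
    args=(double, "/cpu:0", results, tf.range(5, dtype=tf.float32)),
)
thread.start()
thread.join()
print(results[0].numpy())  # [0. 2. 4. 6. 8.]
```

This only moves the "magic line" into a reusable wrapper; it does not change the underlying behavior reported in this issue.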

Please see allenlavoie’s response above:

The device stack is thread local so multiple threads can manage their device stacks independently. I don’t think anyone has considered the inheritance behavior on launching a new thread; it does seem like inheriting the device stack from the parent context would be nice, but I don’t know how feasible that would be. Sounds like a reasonable request at least.
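To make the thread-locality concrete, here is a toy model in plain Python. It is not TensorFlow's actual implementation: each thread gets its own device stack via threading.local, so a plain new thread starts with an empty stack, while a hypothetical InheritingThread snapshots the parent's stack when it is constructed. That snapshot is roughly the inheritance behavior suggested above. The names device, current_device and InheritingThread are illustrative assumptions.

```python
import contextlib
import threading

# Toy model only: this is NOT how TensorFlow implements tf.device internally.
_local = threading.local()


def device_stack():
    """Return the calling thread's device stack (new threads start empty)."""
    if not hasattr(_local, "stack"):
        _local.stack = []
    return _local.stack


@contextlib.contextmanager
def device(name):
    """Toy stand-in for tf.device: push onto the calling thread's stack."""
    stack = device_stack()
    stack.append(name)
    try:
        yield
    finally:
        stack.pop()


def current_device():
    stack = device_stack()
    return stack[-1] if stack else None


class InheritingThread(threading.Thread):
    """Hypothetical thread that snapshots the parent's device stack."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # __init__ runs on the parent thread, so this copies the parent's stack.
        self._parent_stack = list(device_stack())

    def run(self):
        # run() executes on the new thread; seed its stack before the target runs.
        _local.stack = list(self._parent_stack)
        super().run()


seen = {}


def record(key):
    seen[key] = current_device()


with device("/job:worker"):
    for t in (threading.Thread(target=record, args=("plain",)),
              InheritingThread(target=record, args=("inheriting",))):
        t.start()
        t.join()

print(seen["plain"], seen["inheriting"])  # None /job:worker
```

The model copies the stack rather than sharing it, so scopes pushed or popped later in either thread do not leak into the other, which keeps the "threads manage their device stacks independently" property intact.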