tensorflow: TPUStrategy.run fails on non-primary thread with "No OpKernel was registered", "TPUReplicatedInput"
System information
- Have I written custom code (as opposed to using a stock example script provided in TensorFlow): Yes, included here
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04): (a) Google Colab, (b) Debian 10 (Buster), Google Cloud VM TF2.3.1 image
- TensorFlow installed from (source or binary): (a) Google Colab, (b) Google Cloud VM TF2.3.1 image
- TensorFlow version (use command below): (a) Google Colab, (b) v2.3.0-54-gfcc4b966f1 2.3.1
- Python version: (a) Google Colab, (b) 3.7.3 (Google Cloud VM TF2.3.1 image)
- GPU model and memory: (a) Google Colab TPU, (b) Google Cloud TPU, us-central1-f, v2-8, 2.3.1
Describe the current behavior
Running TPUStrategy.run on a secondary thread crashes with InvalidArgumentError: "No OpKernel was registered to support Op 'TPUReplicatedInput'".
Describe the expected behavior
Running TPUStrategy.run on a secondary thread behaves the same as on the primary thread, succeeding.
This is a blocker for applications such as inference during AlphaZero-style MCTS self-play driven from C++, using TPUStrategy.experimental_distribute_dataset, where either the primary C++ thread must perform other housekeeping, or multiple prediction threads are needed to keep the inference pipeline full. However, this minimal Python-only example produces the same error. The error differs slightly when experimental_distribute_dataset is used, but the final "Op:__inference_tpu_function_177" is common to both.
Standalone code to reproduce the issue
This snippet runs in a Google Colab notebook:
import tensorflow as tf
import threading

resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
strategy = tf.distribute.TPUStrategy(resolver)

@tf.function
def double(x):
    return x * 2.0

def test():
    input = tf.range(5, dtype=tf.float32)
    strategy.run(double, args=(input,))

test()
print("TPUStrategy.run works on primary thread")

thread = threading.Thread(target=test)
thread.start()
thread.join()
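As a possible workaround (a hypothetical sketch, not an official TensorFlow API): worker threads can hand their inputs to the primary thread through a queue, so that strategy.run itself is only ever invoked on the thread that initialized the TPU system. The run_on_primary/primary_loop helpers below are illustrative names, and the lambda stands in for strategy.run(double, args=(x,)):

```python
import queue
import threading

class _Job:
    """Pairs an input with a slot for its result and a completion signal."""
    def __init__(self, arg):
        self.arg = arg
        self.done = threading.Event()
        self.result = None

jobs = queue.Queue()

def run_on_primary(arg):
    """Called from worker threads: enqueue the input, block until done."""
    job = _Job(arg)
    jobs.put(job)
    job.done.wait()
    return job.result

def primary_loop(step_fn, num_jobs):
    """Called on the primary thread: runs step_fn (e.g. strategy.run) per job."""
    for _ in range(num_jobs):
        job = jobs.get()
        job.result = step_fn(job.arg)
        job.done.set()

# Demo: four worker threads, each result computed on the primary thread.
results = {}

def worker(x):
    results[x] = run_on_primary(x)

threads = [threading.Thread(target=worker, args=(i,)) for i in range(4)]
for t in threads:
    t.start()
primary_loop(lambda x: x * 2.0, num_jobs=4)  # stand-in for the TPU step
for t in threads:
    t.join()
print(results)
```

This keeps the pipeline full (workers block only on their own job) while satisfying the apparent restriction that the TPU-initializing thread issue the replicated ops.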
Other info / logs
Exception in thread Thread-5:
Traceback (most recent call last):
  File "/usr/lib/python3.6/threading.py", line 916, in _bootstrap_inner
    self.run()
  File "/usr/lib/python3.6/threading.py", line 864, in run
    self._target(*self._args, **self._kwargs)
  File "<ipython-input-2-c35bbe6196e2>", line 15, in test
    strategy.run(double, args=(input,))
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/tpu_strategy.py", line 279, in run
    return self.extended.tpu_run(fn, args, kwargs, options)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/tpu_strategy.py", line 1095, in tpu_run
    return func(args, kwargs)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/def_function.py", line 780, in __call__
    result = self._call(*args, **kwds)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/def_function.py", line 814, in _call
    results = self._stateful_fn(*args, **kwds)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py", line 2829, in __call__
    return graph_function._filtered_call(args, kwargs)  # pylint: disable=protected-access
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py", line 1848, in _filtered_call
    cancellation_manager=cancellation_manager)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py", line 1924, in _call_flat
    ctx, args, cancellation_manager=cancellation_manager))
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py", line 550, in call
    ctx=ctx)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/execute.py", line 60, in quick_execute
    inputs, attrs, num_outputs)
tensorflow.python.framework.errors_impl.InvalidArgumentError: No OpKernel was registered to support Op 'TPUReplicatedInput' used by {{node input0}} with these attrs: [T=DT_INT32, index=0, is_mirrored_variable=false, N=8, is_packed=false]
Registered devices: [CPU, TPU, TPU_SYSTEM, XLA_CPU]
Registered kernels: <no registered kernels>
	 [[input0]] [Op:__inference_tpu_function_177]
About this issue
- Original URL
- State: open
- Created 4 years ago
- Comments: 18 (11 by maintainers)
Hi @mohantym - in your 2.9 and 2.11 gists you are actually pasting the workaround, not the original repro.
Please see my comment on April 23rd, above: