jax: TPU not working
Description
I created a new TPU VM and successfully used it, but after stopping and starting the VM I am getting the following error:
Python 3.8.10 (default, Jun 22 2022, 20:18:18)
[GCC 9.4.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import jax
>>> jax.devices()
Traceback (most recent call last):
File "/home/chris/.local/lib/python3.8/site-packages/jax/_src/lib/xla_bridge.py", line 333, in backends
backend = _init_backend(platform)
File "/home/chris/.local/lib/python3.8/site-packages/jax/_src/lib/xla_bridge.py", line 385, in _init_backend
backend = factory()
File "/home/chris/.local/lib/python3.8/site-packages/jax/_src/lib/xla_bridge.py", line 191, in tpu_client_timer_callback
client = xla_client.make_tpu_client()
File "/home/chris/.local/lib/python3.8/site-packages/jaxlib/xla_client.py", line 122, in make_tpu_client
return _xla.get_tpu_client(
jaxlib.xla_extension.XlaRuntimeError: NOT_FOUND: No ba16c7433 device found.
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/home/chris/.local/lib/python3.8/site-packages/jax/_src/lib/xla_bridge.py", line 483, in devices
return get_backend(backend).devices()
File "/home/chris/.local/lib/python3.8/site-packages/jax/_src/lib/xla_bridge.py", line 425, in get_backend
return _get_backend_uncached(platform)
File "/home/chris/.local/lib/python3.8/site-packages/jax/_src/lib/xla_bridge.py", line 409, in _get_backend_uncached
bs = backends()
File "/home/chris/.local/lib/python3.8/site-packages/jax/_src/lib/xla_bridge.py", line 350, in backends
raise RuntimeError(err_msg)
RuntimeError: Unable to initialize backend 'tpu': NOT_FOUND: No ba16c7433 device found. (set JAX_PLATFORMS='' to automatically choose an available backend)
What jax/jaxlib version are you using?
jax 0.3.25[tpu], jaxlib 0.3.25
Which accelerator(s) are you using?
TPU
Additional system info
No response
NVIDIA GPU info
No response
About this issue
- State: open
- Created 2 years ago
- Comments: 37 (6 by maintainers)
Ah, yes only one framework can use the TPU at a time, so if you import tensorflow before jax, jax won’t be able to access the TPU. I suggest uninstalling tensorflow and reinstalling the CPU-only version, something like:
!pip uninstall -y tensorflow
!pip install tensorflow-cpu
That way you can use tensorflow for non-TPU functions like tf.data, and jax for the TPU parts.
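The import-order constraint described above can be checked before JAX touches the TPU. The following is a minimal sketch and not part of JAX or TensorFlow: the helper name tpu_competitors_loaded and the COMPETING_MODULES list are assumptions made for illustration. It only inspects sys.modules, so it reports that TensorFlow was imported in this process, not whether TensorFlow actually holds the TPU.

```python
import sys

# Hypothetical list of top-level modules that might claim the TPU first.
COMPETING_MODULES = ("tensorflow",)


def tpu_competitors_loaded(loaded=None):
    """Return the competing frameworks already imported in this process."""
    loaded = sys.modules if loaded is None else loaded
    return [name for name in COMPETING_MODULES if name in loaded]


if __name__ == "__main__":
    found = tpu_competitors_loaded()
    if found:
        print("Imported before jax:", ", ".join(found))
    else:
        print("No competing framework imported yet.")
```

Note that the tensorflow-cpu package is still imported under the module name tensorflow, so a hit from this check is a hint to look at the import order, not proof that the TPU is taken.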
@gauravbrills I’m not sure this workaround will work for Keras NLP. I believe there are some preliminary discussions about allowing multiple frameworks to access the TPU concurrently, but I don’t have a timeline for this yet. I’m not very familiar with Keras – are you trying to use Keras with a JAX backend?
@skye @mattjj I faced the same issue today on Kaggle. There is an interesting thing that I found though. If you are importing TensorFlow before JAX, then somehow JAX can’t initialize the TPU runtime.
I am sure that if I remove everything related to TF in the code, it would work. The problem is that we won’t be able to use tfds for loading and processing datasets. Let me know if you need any more info from my side.
Update: I figured out the root cause, at least for Kaggle Notebooks. Depending on which framework accesses the TPU first (TensorFlow or JAX), the other won’t be able to initialize and use it. The bug is definitely related to TF, as I am unable to hide the TPU system from the list of available devices. I’m going to open an issue in the TF repo.
@defdet this appears to have worked. Thanks! A little confusing because ubuntu-2204-base is not listed as a software option in the console, but works when executed from the command line. Made me think it was not a supported software option for v5.
Same here. I don’t know if it worked before I updated/upgraded the APT packages and rebooted the machine, but now I get:
and there is no device either:
@zeyuwang615 the process ID of the process using the TPU is included in the error message (that’s how @ayaka14732 knew to kill pid 137872 from your above example). You can also use sudo lsof -w /dev/accel0 on the command line to find processes using /dev/accel0 (one of the TPU chips).
Please file a new issue if you have more questions. I’d like to keep this thread focused on the missing /dev/accel0 driver.
Hello, I seem to be running into the same issue. I did not stop/start the TPU VM though. It just seems my experiment stopped running and when I try to run it again, I have the following error, when it was working just fine before:
Yes, creating a new VM seems to work until we stop / start the instance. We’re using this as a workaround until this issue can be resolved.
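The thread mixes two different failure modes: the /dev/accel0 device node is missing altogether, or the node exists but another process is holding it. A rough way to tell them apart from Python is sketched below. It assumes the /dev/accel* naming mentioned in this thread and a Linux /proc layout; the helper names list_accel_nodes and pids_holding are hypothetical and not part of JAX. The second helper approximates what lsof reports by reading the /proc/<pid>/fd symlinks.

```python
import glob
import os


def list_accel_nodes(dev_dir="/dev"):
    """Return sorted accel* device nodes under dev_dir (empty if none exist)."""
    return sorted(glob.glob(os.path.join(dev_dir, "accel*")))


def pids_holding(path, proc_root="/proc"):
    """Return PIDs that have an open file descriptor pointing at path.

    Processes owned by other users are silently skipped unless this runs
    with enough privileges to read their fd directories.
    """
    target = os.path.realpath(path)
    pids = []
    for entry in os.listdir(proc_root):
        if not entry.isdigit():
            continue  # not a process directory
        fd_dir = os.path.join(proc_root, entry, "fd")
        try:
            fds = os.listdir(fd_dir)
        except OSError:
            continue  # process exited or permission denied
        for fd in fds:
            if os.path.realpath(os.path.join(fd_dir, fd)) == target:
                pids.append(int(entry))
                break
    return sorted(pids)


if __name__ == "__main__":
    nodes = list_accel_nodes()
    if not nodes:
        print("No /dev/accel* nodes found")
    for node in nodes:
        print(node, "held by PIDs:", pids_holding(node))
```

If no nodes are listed, the situation matches the missing-device case this issue is about. If nodes exist and a PID is reported, it is more likely the device-busy case discussed in the comments above.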