jax: TPU not working

Description

I created a new TPU VM and successfully used it, but after stopping and restarting the VM I get the following error:

Python 3.8.10 (default, Jun 22 2022, 20:18:18) 
[GCC 9.4.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import jax
>>> jax.devices()
Traceback (most recent call last):
  File "/home/chris/.local/lib/python3.8/site-packages/jax/_src/lib/xla_bridge.py", line 333, in backends
    backend = _init_backend(platform)
  File "/home/chris/.local/lib/python3.8/site-packages/jax/_src/lib/xla_bridge.py", line 385, in _init_backend
    backend = factory()
  File "/home/chris/.local/lib/python3.8/site-packages/jax/_src/lib/xla_bridge.py", line 191, in tpu_client_timer_callback
    client = xla_client.make_tpu_client()
  File "/home/chris/.local/lib/python3.8/site-packages/jaxlib/xla_client.py", line 122, in make_tpu_client
    return _xla.get_tpu_client(
jaxlib.xla_extension.XlaRuntimeError: NOT_FOUND: No ba16c7433 device found.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/chris/.local/lib/python3.8/site-packages/jax/_src/lib/xla_bridge.py", line 483, in devices
    return get_backend(backend).devices()
  File "/home/chris/.local/lib/python3.8/site-packages/jax/_src/lib/xla_bridge.py", line 425, in get_backend
    return _get_backend_uncached(platform)
  File "/home/chris/.local/lib/python3.8/site-packages/jax/_src/lib/xla_bridge.py", line 409, in _get_backend_uncached
    bs = backends()
  File "/home/chris/.local/lib/python3.8/site-packages/jax/_src/lib/xla_bridge.py", line 350, in backends
    raise RuntimeError(err_msg)
RuntimeError: Unable to initialize backend 'tpu': NOT_FOUND: No ba16c7433 device found. (set JAX_PLATFORMS='' to automatically choose an available backend)

What jax/jaxlib version are you using?

jax 0.3.25[tpu], jaxlib 0.3.25

Which accelerator(s) are you using?

TPU

Additional system info

No response

NVIDIA GPU info

No response

About this issue

  • State: open
  • Created 2 years ago
  • Comments: 37 (6 by maintainers)

Most upvoted comments

Ah, yes only one framework can use the TPU at a time, so if you import tensorflow before jax, jax won’t be able to access the TPU. I suggest uninstalling tensorflow and reinstalling the CPU-only version, something like:

!pip uninstall -y tensorflow
!pip install tensorflow-cpu

That way you can use tensorflow for non-TPU functions like tf.data, and jax for the TPU parts.

@gauravbrills I’m not sure this workaround will work for Keras NLP. I believe there are some preliminary discussions about allowing multiple frameworks to access the TPU concurrently, but I don’t have a timeline for this yet. I’m not very familiar with Keras – are you trying to use Keras with a JAX backend?

@skye @mattjj I faced the same issue today on Kaggle. There is an interesting thing that I found though. If you are importing TensorFlow before JAX, then somehow JAX can’t initialize the TPU runtime.

I am sure that if I remove everything related to TF in the code, it would work. The problem is that we won’t be able to use tfds for loading and processing datasets. Let me know if you need any more info from my side.

Update: I figured out the root cause, at least for Kaggle Notebooks. Whichever framework accesses the TPU first (TensorFlow or JAX) prevents the other from initializing and using it. The bug is definitely related to TF, as I am unable to hide the TPU system from the list of available devices. I’m going to open an issue in the TF repo.
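Since the failure depends on import order, a small guard run just before `import jax` can flag the situation early. The following is a minimal sketch, not part of JAX or TensorFlow: the helper names `loaded_tpu_frameworks` and `warn_if_tpu_may_be_claimed` are hypothetical, and the check only inspects `sys.modules`, so it cannot tell a TPU-enabled TensorFlow build from `tensorflow-cpu`.

```python
import sys
import warnings


def loaded_tpu_frameworks(candidates=("tensorflow",), modules=None):
    """Return the candidate framework names that are already imported.

    `modules` defaults to sys.modules; it is a parameter only so the
    check can be exercised without importing the real frameworks.
    """
    modules = sys.modules if modules is None else modules
    return [name for name in candidates if name in modules]


def warn_if_tpu_may_be_claimed():
    """Warn if TensorFlow was imported first. Call right before `import jax`."""
    loaded = loaded_tpu_frameworks()
    if loaded:
        # Assumption from this thread: the first framework to touch the TPU
        # keeps it. A CPU-only TensorFlow build would not, but it still shows
        # up here because it is imported under the same module name.
        warnings.warn(
            f"{', '.join(loaded)} was imported before jax; on a TPU host it "
            "may already hold the TPU."
        )
    return loaded
```

Calling `warn_if_tpu_may_be_claimed()` at the top of a notebook cell, before `import jax`, gives a hint about the import order without touching the TPU itself.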

@defdet this appears to have worked. Thanks! A little confusing because ubuntu-2204-base is not listed as a software option in the console, but works when executed from the command line. Made me think it was not a supported software option for v5.

Same here. I don’t know if it worked before I updated/upgraded the APT packages and rebooted the machine, but now I get:

python3 -c "import jax; print(jax.local_devices())"
Traceback (most recent call last):
  File "/home/vincent/charred/lib/python3.8/site-packages/jax/_src/xla_bridge.py", line 393, in backends
    backend = _init_backend(platform)
  File "/home/vincent/charred/lib/python3.8/site-packages/jax/_src/xla_bridge.py", line 445, in _init_backend
    backend = factory()
  File "/home/vincent/charred/lib/python3.8/site-packages/jax/_src/xla_bridge.py", line 187, in tpu_client_timer_callback
    client = xla_client.make_tpu_client()
  File "/home/vincent/charred/lib/python3.8/site-packages/jaxlib/xla_client.py", line 147, in make_tpu_client
    return _xla.get_tpu_client(
jaxlib.xla_extension.XlaRuntimeError: UNAVAILABLE: No TPU Platform available.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/home/vincent/charred/lib/python3.8/site-packages/jax/_src/xla_bridge.py", line 575, in local_devices
    process_index = get_backend(backend).process_index()
  File "/home/vincent/charred/lib/python3.8/site-packages/jax/_src/xla_bridge.py", line 485, in get_backend
    return _get_backend_uncached(platform)
  File "/home/vincent/charred/lib/python3.8/site-packages/jax/_src/xla_bridge.py", line 469, in _get_backend_uncached
    bs = backends()
  File "/home/vincent/charred/lib/python3.8/site-packages/jax/_src/xla_bridge.py", line 410, in backends
    raise RuntimeError(err_msg)
RuntimeError: Unable to initialize backend 'tpu': UNAVAILABLE: No TPU Platform available. (set JAX_PLATFORMS='' to automatically choose an available backend)

and there is no device either:

ls -la /dev/accel*
ls: cannot access '/dev/accel*': No such file or directory
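The same `ls` check can be done from Python before initializing JAX, which helps separate "the device nodes are missing" from the other errors reported in this thread. This is a minimal sketch using only the standard library; the function names are hypothetical, and the `/dev/accel*` pattern is taken from the command above.

```python
import glob


def tpu_device_nodes(pattern="/dev/accel*"):
    """Return the sorted device nodes matching `pattern` (empty if none exist)."""
    return sorted(glob.glob(pattern))


def describe_tpu_nodes(pattern="/dev/accel*"):
    """Summarize whether any TPU device nodes are present."""
    nodes = tpu_device_nodes(pattern)
    if not nodes:
        # Matches the situation above: `ls /dev/accel*` finds nothing.
        return f"no device nodes match {pattern}"
    return "found: " + ", ".join(nodes)


if __name__ == "__main__":
    print(describe_tpu_nodes())
```

An empty result here corresponds to the `ls: cannot access '/dev/accel*'` output shown above.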

@zeyuwang615 the process ID of the process using the TPU is included in the error message (that’s how @ayaka14732 knew to kill pid 137872 from your above example). You can also use sudo lsof -w /dev/accel0 on the command line to find processes using /dev/accel0 (one of the TPU chips).
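Where `lsof` is not installed, a rough equivalent can be sketched in Python by scanning `/proc` for open file descriptors that point at the device. This is an illustrative sketch, not how `lsof` or JAX does it: the function name `pids_holding` is hypothetical, it only works on Linux, and without root it can only see processes owned by the current user (which is why the command above uses `sudo`).

```python
import os


def pids_holding(path):
    """Return PIDs with an open file descriptor on `path`, found via /proc."""
    target = os.path.realpath(path)
    pids = []
    if not os.path.isdir("/proc"):
        return pids  # not a Linux-style /proc; nothing to scan
    for entry in os.listdir("/proc"):
        if not entry.isdigit():
            continue  # skip non-process entries such as /proc/cpuinfo
        fd_dir = f"/proc/{entry}/fd"
        try:
            fds = os.listdir(fd_dir)
        except OSError:
            continue  # process exited, or we lack permission to inspect it
        for fd in fds:
            try:
                if os.readlink(os.path.join(fd_dir, fd)) == target:
                    pids.append(int(entry))
                    break
            except OSError:
                continue  # descriptor was closed while we were scanning
    return sorted(pids)


if __name__ == "__main__":
    print(pids_holding("/dev/accel0"))
```

Any PID it prints can then be inspected (for example with `ps -p <pid>`) before deciding whether to stop that process.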

Please file a new issue if you have more questions. I’d like to keep this thread focused on the missing /dev/accel0 driver.

Hello, I seem to be running into the same issue. I did not stop/start the TPU VM, though. My experiment just seems to have stopped running, and when I try to run it again I get the following error, even though it was working fine before:

  File "/home/melisande/.local/lib/python3.9/site-packages/jax/_src/lib/xla_bridge.py", line 333, in backends
    backend = _init_backend(platform)
  File "/home/melisande/.local/lib/python3.9/site-packages/jax/_src/lib/xla_bridge.py", line 385, in _init_backend
    backend = factory()
  File "/home/melisande/.local/lib/python3.9/site-packages/jax/_src/lib/xla_bridge.py", line 191, in tpu_client_timer_callback
    client = xla_client.make_tpu_client()
  File "/home/melisande/.local/lib/python3.9/site-packages/jaxlib/xla_client.py", line 122, in make_tpu_client
    return _xla.get_tpu_client(
jaxlib.xla_extension.XlaRuntimeError: PERMISSION_DENIED: open(/dev/accel0): Operation not permitted: Operation not permitted; Couldn't open device: /dev/accel0; Unable to create Node RegisterInterface for node 0, config: device_path:    "/dev/accel0" mode: KERNEL debug_data_directory: "" dump_anomalies_only: true crash_in_debug_dump: false allow_core_dump: true; could not create driver instance

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/melisande/scenic-glc/scenic/projects/glc/main.py", line 44, in <module>
    devices = jax.local_devices()
  File "/home/melisande/.local/lib/python3.9/site-packages/jax/_src/lib/xla_bridge.py", line 515, in local_devices
    process_index = get_backend(backend).process_index()
  File "/home/melisande/.local/lib/python3.9/site-packages/jax/_src/lib/xla_bridge.py", line 425, in get_backend
    return _get_backend_uncached(platform)
  File "/home/melisande/.local/lib/python3.9/site-packages/jax/_src/lib/xla_bridge.py", line 409, in _get_backend_uncached
    bs = backends()
  File "/home/melisande/.local/lib/python3.9/site-packages/jax/_src/lib/xla_bridge.py", line 350, in backends
    raise RuntimeError(err_msg)
RuntimeError: Unable to initialize backend 'tpu': PERMISSION_DENIED: open(/dev/accel0): Operation not permitted: Operation not permitted; Couldn't open device: /dev/accel0; Unable to create Node RegisterInterface for node 0, config: device_path:        "/dev/accel0" mode: KERNEL debug_data_directory: "" dump_anomalies_only: true crash_in_debug_dump: false allow_core_dump: true; could not create driver instance (set JAX_PLATFORMS='' to automatically choose an available backend)

Yes, creating a new VM seems to work until we stop / start the instance. We’re using this as a workaround until this issue can be resolved.