jax: Colab TPU setup fails with nightly driver
import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu()
jax.local_devices()
RuntimeError Traceback (most recent call last)
<ipython-input-2-abdb42bb8a9a> in <module>()
3 import jax.tools.colab_tpu
4 jax.tools.colab_tpu.setup_tpu()
----> 5 jax.local_devices()
2 frames
/usr/local/lib/python3.7/dist-packages/jax/_src/lib/xla_bridge.py in _get_backend_uncached(platform)
246 if backend is None:
247 if platform in _backends_errors:
--> 248 raise RuntimeError(f"Requested backend {platform}, but it failed "
249 f"to initialize: {_backends_errors[platform]}")
250 raise RuntimeError(f"Unknown backend {platform}")
RuntimeError: Requested backend tpu_driver, but it failed to initialize: DEADLINE_EXCEEDED: Failed to connect to remote server at address: grpc://10.91.25.74:8470. Error from gRPC: Deadline Exceeded. Details:
On the other hand, requesting an earlier version of the driver in setup_tpu does work, e.g.:
def setup_tpu():
  """Sets up Colab to run on TPU.

  Note: make sure the Colab Runtime is set to Accelerator: TPU.
  """
  global TPU_DRIVER_MODE

  if not TPU_DRIVER_MODE:
    colab_tpu_addr = os.environ['COLAB_TPU_ADDR'].split(':')[0]
    url = f'http://{colab_tpu_addr}:8475/requestversion/tpu_driver0.1-dev20211030'
    requests.post(url)
    TPU_DRIVER_MODE = 1

  # The following is required to use TPU Driver as JAX's backend.
  config.FLAGS.jax_xla_backend = "tpu_driver"
  config.FLAGS.jax_backend_target = "grpc://" + os.environ['COLAB_TPU_ADDR']
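To make it easier to try different driver builds, the two strings that setup_tpu derives from COLAB_TPU_ADDR can be factored out. This is only a sketch: driver_request_url, backend_target and KNOWN_GOOD_DRIVER are hypothetical names, not part of jax.tools.colab_tpu, and the default version is simply the one reported as working above. It does not contact the TPU; it only builds the strings so they can be inspected before calling requests.post.

```python
# Hypothetical default: the dated driver build reported as working in this issue.
KNOWN_GOOD_DRIVER = "tpu_driver0.1-dev20211030"


def driver_request_url(colab_tpu_addr, driver_version=KNOWN_GOOD_DRIVER):
    """Build the URL that setup_tpu() POSTs to in order to request a driver.

    colab_tpu_addr is the value of COLAB_TPU_ADDR, e.g. "10.91.25.74:8470".
    """
    host = colab_tpu_addr.split(":")[0]  # keep the host, drop the gRPC port
    return f"http://{host}:8475/requestversion/{driver_version}"


def backend_target(colab_tpu_addr):
    """Build the gRPC target that jax_backend_target is set to."""
    return "grpc://" + colab_tpu_addr


# Example with the address from the traceback above; on Colab this would
# come from os.environ['COLAB_TPU_ADDR'] instead.
addr = "10.91.25.74:8470"
print(driver_request_url(addr))
print(backend_target(addr))
```

Passing a different driver_version string is then a one-argument change, which helps when checking which builds the TPU host accepts.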
About this issue
- State: closed
- Created 3 years ago
- Comments: 15
@rog77 replace
with
I am having the same issue.
import requests
import jax
import os
import numpy as np

from jax.config import config

TPU_DRIVER_MODE = 0

def setup_tpu():
  """Sets up Colab to run on TPU.

  Note: make sure the Colab Runtime is set to Accelerator: TPU.
  """
  global TPU_DRIVER_MODE

  if not TPU_DRIVER_MODE:
    colab_tpu_addr = os.environ['COLAB_TPU_ADDR'].split(':')[0]
    url = f'http://{colab_tpu_addr}:8475/requestversion/tpu_driver0.1-dev20211030'
    requests.post(url)
    TPU_DRIVER_MODE = 1

  config.FLAGS.jax_xla_backend = "tpu_driver"
  config.FLAGS.jax_backend_target = "grpc://" + os.environ['COLAB_TPU_ADDR']

setup_tpu()
jax.local_devices()
^^
This does indeed give me:
[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]
On Colab.
Thank you, @luyug!