jax: Colab TPU setup fails with nightly driver

Please:

  • Check for duplicate issues.
  • Provide a complete example of how to reproduce the bug, wrapped in triple backticks like this:
import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu()
jax.local_devices()
  • If applicable, include full error messages/tracebacks.
RuntimeError                              Traceback (most recent call last)
<ipython-input-2-abdb42bb8a9a> in <module>()
      3 import jax.tools.colab_tpu
      4 jax.tools.colab_tpu.setup_tpu()
----> 5 jax.local_devices()

2 frames
/usr/local/lib/python3.7/dist-packages/jax/_src/lib/xla_bridge.py in _get_backend_uncached(platform)
    246     if backend is None:
    247       if platform in _backends_errors:
--> 248         raise RuntimeError(f"Requested backend {platform}, but it failed "
    249                            f"to initialize: {_backends_errors[platform]}")
    250       raise RuntimeError(f"Unknown backend {platform}")

RuntimeError: Requested backend tpu_driver, but it failed to initialize: DEADLINE_EXCEEDED: Failed to connect to remote server at address: grpc://10.91.25.74:8470. Error from gRPC: Deadline Exceeded. Details: 

On the other hand, pinning an earlier version of the driver in setup_tpu does work, e.g.:

def setup_tpu():
  """Sets up Colab to run on TPU.
  Note: make sure the Colab Runtime is set to Accelerator: TPU.
  """
  global TPU_DRIVER_MODE

  if not TPU_DRIVER_MODE:
    colab_tpu_addr = os.environ['COLAB_TPU_ADDR'].split(':')[0]
    url = f'http://{colab_tpu_addr}:8475/requestversion/tpu_driver0.1-dev20211030'
    requests.post(url)
    TPU_DRIVER_MODE = 1

  # The following is required to use TPU Driver as JAX's backend.
  config.FLAGS.jax_xla_backend = "tpu_driver"
  config.FLAGS.jax_backend_target = "grpc://" + os.environ['COLAB_TPU_ADDR']
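
The only thing the workaround changes is the driver version string in the request URL. A minimal sketch that isolates that step is shown below. The helper name `driver_request_url` is hypothetical and not part of `jax.tools.colab_tpu`; the port `8475` and the `/requestversion/` path are copied from the snippet above, and which version strings the Colab TPU host accepts is not something this sketch can verify.

```python
def driver_request_url(colab_tpu_addr: str, driver_version: str) -> str:
    """Build the URL that asks the Colab TPU host to load a given driver build.

    `colab_tpu_addr` is the raw value of COLAB_TPU_ADDR ("host:port").
    Hypothetical helper for illustration only.
    """
    # Drop the gRPC port; the version request goes to port 8475 instead.
    host = colab_tpu_addr.split(':')[0]
    return f'http://{host}:8475/requestversion/{driver_version}'


# Example with the address from the traceback and the driver build that worked.
url = driver_request_url('10.91.25.74:8470', 'tpu_driver0.1-dev20211030')
```

Keeping the version as a parameter makes it easy to try a different dated build without editing the function body; the URL would then be passed to `requests.post(url)` as in the original.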

About this issue

  • State: closed
  • Created 3 years ago
  • Comments: 15

Most upvoted comments

@rog77 replace

import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu()

with

import requests
import os

from jax.config import config

TPU_DRIVER_MODE = 0


def setup_tpu():
  """Sets up Colab to run on TPU.
  Note: make sure the Colab Runtime is set to Accelerator: TPU.
  """
  global TPU_DRIVER_MODE

  if not TPU_DRIVER_MODE:
    colab_tpu_addr = os.environ['COLAB_TPU_ADDR'].split(':')[0]
    url = f'http://{colab_tpu_addr}:8475/requestversion/tpu_driver0.1-dev20211030'
    requests.post(url)
    TPU_DRIVER_MODE = 1

  # The following is required to use TPU Driver as JAX's backend.
  config.FLAGS.jax_xla_backend = "tpu_driver"
  config.FLAGS.jax_backend_target = "grpc://" + os.environ['COLAB_TPU_ADDR']

setup_tpu()

I am having the same issue.

import requests
import jax
import os
import numpy as np

from jax.config import config

TPU_DRIVER_MODE = 0


def setup_tpu():
  """Sets up Colab to run on TPU.
  Note: make sure the Colab Runtime is set to Accelerator: TPU.
  """
  global TPU_DRIVER_MODE

  if not TPU_DRIVER_MODE:
    colab_tpu_addr = os.environ['COLAB_TPU_ADDR'].split(':')[0]
    url = f'http://{colab_tpu_addr}:8475/requestversion/tpu_driver0.1-dev20211030'
    requests.post(url)
    TPU_DRIVER_MODE = 1

  config.FLAGS.jax_xla_backend = "tpu_driver"
  config.FLAGS.jax_backend_target = "grpc://" + os.environ['COLAB_TPU_ADDR']

setup_tpu()

jax.local_devices()

^^

This does indeed give me:

[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]

On Colab.

Thank you, @luyug!
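
For anyone wanting to sanity-check the device listing above, here is a rough sketch that groups cores by chip coordinates. It assumes each device object exposes `coords` and `core_on_chip` attributes, as the repr suggests; the function name `cores_per_chip` and the stand-in objects are illustrative only, so the snippet runs without a TPU.

```python
from collections import defaultdict
from types import SimpleNamespace


def cores_per_chip(devices):
    """Group core indices by chip coordinates."""
    chips = defaultdict(list)
    for dev in devices:
        chips[tuple(dev.coords)].append(dev.core_on_chip)
    return {coords: sorted(cores) for coords, cores in chips.items()}


# Stand-in objects mirroring the listing above; on Colab you would pass
# jax.local_devices() instead.
fake_devices = [
    SimpleNamespace(coords=(x, y, 0), core_on_chip=core)
    for y in (0, 1) for x in (0, 1) for core in (0, 1)
]
summary = cores_per_chip(fake_devices)
```

With the listing from this thread, the grouping should show four chip coordinates with two cores each, i.e. the eight cores expected on a Colab TPU runtime.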