jax: TPU not detected by jax in Colab

I am attempting to use the Google-provided notebook for Reformer with TPU and I get:

/usr/local/lib/python3.6/dist-packages/jax/lib/xla_bridge.py:118: UserWarning: No GPU/TPU found, falling back to CPU.
  warnings.warn('No GPU/TPU found, falling back to CPU.')

It works fine with GPU, on the other hand.

About this issue

  • Original URL
  • State: closed
  • Created 4 years ago
  • Reactions: 4
  • Comments: 21 (8 by maintainers)

Most upvoted comments

These two lines made it work for me:

import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu()
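
As a quick sanity check afterwards (a minimal sketch, assuming the Colab runtime is attached to a TPU), jax.devices() should now list the TPU cores:

import jax

# With the TPU runtime attached and setup_tpu() run, this should print
# 8 TpuDevice objects rather than a single CPU device.
print(jax.devices())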

I’ve not run Reformer, but this may work: try running the following cell first (with the Colab runtime set to TPU acceleration):

# get the latest JAX and jaxlib
!pip install --upgrade -q jax jaxlib

# Colab runtime set to TPU accel
import requests
import os
if 'TPU_DRIVER_MODE' not in globals():
  url = 'http://' + os.environ['COLAB_TPU_ADDR'].split(':')[0] + ':8475/requestversion/tpu_driver_nightly'
  resp = requests.post(url)
  TPU_DRIVER_MODE = 1

# TPU driver as backend for JAX
from jax.config import config
config.FLAGS.jax_xla_backend = "tpu_driver"
config.FLAGS.jax_backend_target = "grpc://" + os.environ['COLAB_TPU_ADDR']
print(config.FLAGS.jax_backend_target)

Source: Cloud TPU NeurIPS 2019 Colab.

The notebook was made by @skye and @mattjj, I think, and demoed by @skye at NeurIPS when TPU support was unveiled. (Yours truly was lucky to have attended that demo.)

@nkitaev Yes, I ran it to completion and I’m getting good results once I account for compilation time. I played with it some to get a better idea, and I’m trying it on a bigger dataset next (Wikipedia), since that’s where it will shine most.

The only thing I’m unsure of is whether I can use the same config with TPUs on gcloud (I haven’t tried yet), given that the settings seem at least partially Colab-specific.

As for this issue: it’s basically resolved for me and can be closed, since I have it working. I assume there are already plans to either include the relevant config steps in the docs or to detect the TPU directly in xla_bridge.py.

@Tenoke if you’re not getting that “No GPU/TPU found” warning, it should be running on TPU (as another way to check, jax.devices() should return 8 TpuDevice objects). Like @nkitaev says, TPUs really shine with large inputs. If you’re only performing tiny, quick computations, non-TPU overheads will dominate the overall time and you won’t see a benefit from hardware acceleration. If you try the demo colab, the microbenchmark at the end of the “The basics: interactive NumPy on GPU and TPU” section should show a difference between TPU and CPU. (Also, that tpu_client import comes from jaxlib here.)
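
For reference, here is a minimal sketch of that kind of microbenchmark (not the exact cell from the demo colab; the matrix size is just illustrative). Large matmuls are where the TPU should clearly outpace the CPU backend:

import time
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (8192, 8192), dtype=jnp.float32)

# Warm-up run so compilation isn't included in the measurement.
jnp.dot(x, x.T).block_until_ready()

start = time.time()
# block_until_ready() waits for the asynchronous computation to finish;
# otherwise we'd only be timing the dispatch.
jnp.dot(x, x.T).block_until_ready()
print(f"matmul took {time.time() - start:.3f}s on {jax.devices()[0]}")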

Thanks for reporting this, btw. Even if it turns out that everything’s working as expected, this is still useful feedback on the initial experience. We’re still working to make JAX + Cloud TPU better!

@8bitmp3 you need to use %%time, not %time. The single-% line magic only times the code on the same line (which is empty here), which is why you get µs.
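
For illustration, two separate Colab cells (the loop is just a stand-in for the real training step):

# Cell A: %time is a line magic; it times only the statement on this same line.
%time total = sum(range(10**7))

and, as its own cell (%%time has to be the very first line of the cell):

%%time
# %%time is a cell magic: it times everything in the cell body below it.
total = 0
for i in range(10**7):
    total += i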

I’m not sure why the Step 1500 time is way lower on CPU than on TPU though.

I’m rerunning them now.

Update: I just ran the Trax Transformer/Reformer Colab after adding and running an extra cell with the above-mentioned code, and it worked.

(Don’t forget to set the accelerator to TPU under Edit > Notebook settings).