transformers: Flax BERT finetuning notebook no longer works on TPUs

System Info

  • Colab
  • transformers version: 4.22.0.dev0
  • Platform: Linux-5.10.133+-x86_64-with-Ubuntu-18.04-bionic
  • Python version: 3.7.13
  • Huggingface_hub version: 0.9.1
  • PyTorch version (GPU?): 1.12.1+cu113 (False)
  • Tensorflow version (GPU?): 2.8.2 (False)
  • Flax version (CPU?/GPU?/TPU?): 0.6.0 (cpu)
  • Jax version: 0.3.17
  • JaxLib version: 0.3.15
  • Using GPU in script?: no
  • Using distributed or parallel set-up in script?: yes
  • Using TPU: yes

Who can help?

@patil-suraj @LysandreJik

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, …)
  • My own task or dataset (give details below)

Reproduction

The problem arises with the official notebook examples/text_classification_flax.ipynb.

The official notebook has some trivial problems (e.g., gradient_transformation is never defined), which are fixed in this slightly modified version.
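
For reference, the modified version patches it along these lines (a sketch of the fix, not the notebook's exact code; learning_rate_fn is the learning-rate schedule created earlier in the notebook):

import optax

# Sketch (assumption): define the missing `gradient_transformation` as an
# AdamW optimizer and pass it as the train state's `tx`.
gradient_transformation = optax.adamw(
    learning_rate=learning_rate_fn,
    b1=0.9,
    b2=0.999,
    weight_decay=0.01,
)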

The notebook gets stuck on compiling at the training loop, and exits with this error:

Epoch ...: 0%
0/3 [00:00<?, ?it/s]
Training...: 0%
0/267 [00:00<?, ?it/s]
---------------------------------------------------------------------------
UnfilteredStackTrace                      Traceback (most recent call last)
<ipython-input-33-e147f5aff5fe> in <module>
      5     with tqdm(total=len(train_dataset) // total_batch_size, desc="Training...", leave=False) as progress_bar_train:
----> 6       for batch in glue_train_data_loader(input_rng, train_dataset, total_batch_size):
      7         state, train_metrics, dropout_rngs = parallel_train_step(state, batch, dropout_rngs)

17 frames
UnfilteredStackTrace: jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Compile failed to finish within 1 hour.

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

XlaRuntimeError                           Traceback (most recent call last)
/usr/local/lib/python3.7/dist-packages/jax/_src/random.py in permutation(key, x, axis, independent)
    413       raise TypeError("x must be an integer or at least 1-dimensional")
    414     r = core.concrete_or_error(int, x, 'argument x of jax.random.permutation()')
--> 415     return _shuffle(key, jnp.arange(r), axis)
    416   if independent or np.ndim(x) == 1:
    417     return _shuffle(key, x, axis)

XlaRuntimeError: INTERNAL: Compile failed to finish within 1 hour.
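
For context, the failure surfaces inside jax.random.permutation only because the data loader's shuffle is the first operation dispatched to the TPU; the loader itself is not at fault. Paraphrased from the notebook (details may differ), it looks roughly like this:

from flax.training.common_utils import shard
import jax
import jax.numpy as jnp

def glue_train_data_loader(rng, dataset, batch_size):
    # Shuffle example indices on-device; this permutation is the first op
    # sent to the TPU, which is why the compile timeout is reported here.
    steps_per_epoch = len(dataset) // batch_size
    perms = jax.random.permutation(rng, len(dataset))
    perms = perms[: steps_per_epoch * batch_size]
    perms = perms.reshape((steps_per_epoch, batch_size))
    for perm in perms:
        batch = {k: jnp.array(v) for k, v in dataset[perm].items()}
        yield shard(batch)  # split the batch across the local TPU cores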

Expected behavior

Training should proceed smoothly instead of the XLA compilation timing out. 😄

About this issue

  • Original URL
  • State: closed
  • Created 2 years ago
  • Comments: 15 (5 by maintainers)

Most upvoted comments

I am very happy that it worked, @NightMachinery! I think it makes sense to have a “reference” colab that people can refer to - pinging @patil-suraj (for the fix, which I borrowed from the diffusers notebook) and @LysandreJik regarding the PR you suggested 😉 Thank you!

Hey @NightMachinery! Can you try these cells for installation? I think I gave you the wrong installation guidelines before:

#@title Set up JAX
#@markdown If you see an error, make sure you are using a TPU backend. Select `Runtime` in the menu above, then select the option "Change runtime type" and then select `TPU` under the `Hardware accelerator` setting.
!pip install --upgrade jax jaxlib 

import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu('tpu_driver_20221011')

!pip install flax diffusers transformers ftfy
jax.devices()

I can confirm jax.devices() gave me:

[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]

This is based on the recent demo from diffusers; see the colab here: colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/stable_diffusion_fast_jax.ipynb

This works! I think the only difference with my previous code is supplying tpu_driver_20221011 to setup_tpu. Where is that documented? I suggest having a central Colab TPU guide in the Hugging Face docs that documents details like this, which are necessary to run any TPU notebook.
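
For reference, a rough before/after sketch (assuming the old notebook calls setup_tpu() with no argument, which is what I had been running):

import jax.tools.colab_tpu

# What I ran before (assumption): the default TPU driver, which is the
# configuration that hit the 1-hour XLA compile timeout.
# jax.tools.colab_tpu.setup_tpu()

# Working setup: pin the newer TPU driver before any other JAX call.
jax.tools.colab_tpu.setup_tpu('tpu_driver_20221011')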

Do you want me to send a PR for this specific notebook?