transformers: Flax BERT finetuning notebook no longer works on TPUs
System Info
- Colab
- `transformers` version: 4.22.0.dev0
- Platform: Linux-5.10.133+-x86_64-with-Ubuntu-18.04-bionic
- Python version: 3.7.13
- Huggingface_hub version: 0.9.1
- PyTorch version (GPU?): 1.12.1+cu113 (False)
- Tensorflow version (GPU?): 2.8.2 (False)
- Flax version (CPU?/GPU?/TPU?): 0.6.0 (cpu)
- Jax version: 0.3.17
- JaxLib version: 0.3.15
- Using GPU in script?: no
- Using distributed or parallel set-up in script?: yes
- Using TPU: yes
Who can help?
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the `examples` folder (such as GLUE/SQuAD, …)
- My own task or dataset (give details below)
Reproduction
The problem arises with the official notebook `examples/text_classification_flax.ipynb`.
The official notebook has a few trivial problems (e.g., `gradient_transformation`
is never defined), which are fixed in this slightly modified version.
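For reference, here is one plausible way to define the missing `gradient_transformation` with `optax`. This is only a sketch: the choice of AdamW, the learning rate, and the clipping threshold are placeholders, not the notebook's actual settings.

```python
import optax

# Hypothetical stand-in for the undefined `gradient_transformation`:
# gradient clipping chained with AdamW, a common setup for BERT finetuning.
# Hyperparameter values below are illustrative placeholders.
gradient_transformation = optax.chain(
    optax.clip_by_global_norm(1.0),   # clip gradients for stability
    optax.adamw(learning_rate=2e-5),  # AdamW optimizer
)
```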
The notebook gets stuck compiling at the training loop and exits with this error:
Epoch ...: 0%
0/3 [00:00<?, ?it/s]
Training...: 0%
0/267 [00:00<?, ?it/s]
---------------------------------------------------------------------------
UnfilteredStackTrace                      Traceback (most recent call last)
<ipython-input-33-e147f5aff5fe> in <module>
      5 with tqdm(total=len(train_dataset) // total_batch_size, desc="Training...", leave=False) as progress_bar_train:
----> 6     for batch in glue_train_data_loader(input_rng, train_dataset, total_batch_size):
      7         state, train_metrics, dropout_rngs = parallel_train_step(state, batch, dropout_rngs)

17 frames

UnfilteredStackTrace: jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Compile failed to finish within 1 hour.

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------

The above exception was the direct cause of the following exception:

XlaRuntimeError                           Traceback (most recent call last)
/usr/local/lib/python3.7/dist-packages/jax/_src/random.py in permutation(key, x, axis, independent)
    413       raise TypeError("x must be an integer or at least 1-dimensional")
    414     r = core.concrete_or_error(int, x, 'argument x of jax.random.permutation()')
--> 415     return _shuffle(key, jnp.arange(r), axis)
    416   if independent or np.ndim(x) == 1:
    417     return _shuffle(key, x, axis)

XlaRuntimeError: INTERNAL: Compile failed to finish within 1 hour.
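The failure surfaces in `jax.random.permutation`, which the notebook's data loader calls to shuffle indices each epoch. A sketch of that pattern, reconstructed from the traceback (the function and variable names come from the trace; details of the real loader may differ):

```python
import jax
import jax.numpy as jnp
from flax.training.common_utils import shard

def glue_train_data_loader(rng, dataset, batch_size):
    """Shuffle the dataset and yield one sharded batch per train step."""
    steps_per_epoch = len(dataset) // batch_size
    # This permutation is where XLA compilation hangs until the 1-hour timeout.
    perms = jax.random.permutation(rng, len(dataset))
    perms = perms[: steps_per_epoch * batch_size]  # drop the last partial batch
    perms = perms.reshape((steps_per_epoch, batch_size))
    for perm in perms:
        batch = dataset[perm]
        batch = {k: jnp.array(v) for k, v in batch.items()}
        yield shard(batch)  # split the leading axis across local TPU devices
```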
Expected behavior
The training is supposed to go smoothly. 😄
About this issue
- Original URL
- State: closed
- Created 2 years ago
- Comments: 15 (5 by maintainers)
I am very happy that it worked, @NightMachinery! I think it makes sense to have a "reference" Colab that people can refer to - pinging @patil-suraj (for the fix I borrowed from the diffusers notebook) and @LysandreJik regarding the PR you have suggested 😉 Thank you!
This works! I think the only difference from my previous code is supplying `tpu_driver_20221011` to `setup_tpu`. Where is that documented? I suggest having a central Colab TPU guide in the HuggingFace docs that documents things like this that are necessary to run any TPU notebook. Do you want me to send a PR for this specific notebook?
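For anyone hitting the same timeout, the fix boils down to pinning the Colab TPU driver version when initializing JAX. A minimal sketch, using the driver string quoted above from the diffusers notebook:

```python
import jax.tools.colab_tpu

# Pin the Colab TPU driver to the build mentioned above; the default
# driver can leave XLA compilation hanging until the 1-hour timeout.
jax.tools.colab_tpu.setup_tpu('tpu_driver_20221011')

import jax
print(jax.devices())  # should now list the TPU cores instead of CPU
```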