transformers: model.generate() is very slow on a TPU VM; it seems the TPU is not being used. What is the problem?

System Info

When I run model.generate() on a TPU VM, inference is very slow; it seems the TPU is not being used. What is the problem? The JAX TPU devices do exist (see the device check in the Reproduction script below).

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, …)
  • My own task or dataset (give details below)

Reproduction

import jax

# confirm that JAX can actually see the TPU devices
num_devices = jax.device_count()
device_type = jax.devices()[0].device_kind
assert "TPU" in device_type

from transformers import FlaxMT5ForConditionalGeneration, T5Tokenizer

model = FlaxMT5ForConditionalGeneration.from_pretrained("google/mt5-small")
tokenizer = T5Tokenizer.from_pretrained("google/mt5-small")

input_context = "The dog"
# encode the input context
input_ids = tokenizer(input_context, return_tensors="np").input_ids
# generate candidates using sampling
outputs = model.generate(input_ids=input_ids, max_length=20, top_k=30, do_sample=True)
print(outputs)

Expected behavior

Generation is expected to run on the TPU and be fast.

About this issue

  • Original URL
  • State: closed
  • Created 2 years ago
  • Comments: 28 (10 by maintainers)

Most upvoted comments

Hey @joytianya! Sorry about the late reply here! Cool to see that you’re using the Flax MT5 model!

The big speed-up from using JAX on TPU comes from JIT compiling a function: https://jax.readthedocs.io/en/latest/jax-101/02-jitting.html. It’s worth reading this guide to get a feel for how JAX + XLA + TPU work in combination to give you fast kernel execution.
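For intuition, here is a toy example in the spirit of that guide (the selu function is the one used in the JAX jitting tutorial; this snippet is illustrative and not from the issue itself):

import jax
import jax.numpy as jnp

@jax.jit
def selu(x, alpha=1.67, lmbda=1.05):
    # XLA compiles the whole function into a single fused kernel
    return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

x = jnp.arange(5.0)
print(selu(x))  # first call compiles; later calls reuse the compiled kernel

The first call pays a one-off compilation cost; every subsequent call with the same input shapes dispatches straight to the pre-compiled XLA kernel on the accelerator.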

I’ve written an ipynb notebook that demonstrates how you can JIT compile the generate method: https://github.com/sanchit-gandhi/codesnippets/blob/main/benchmark_flaxmt5_jit_generate.ipynb

Running this with a ‘tiny’ version of the Flax MT5 model on CPU, I get a 75x speed-up from JIT compiling the generate function versus the vanilla generate function! That’s fast, right?

You can adapt the script for the mt5-small checkpoint as you require 🤗 You’ll need to pass any additional args that introduce boolean control flow in the generate method under static_argnames (as is done with max_length, top_k and do_sample).
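For reference, here is a minimal sketch of that pattern (it mirrors the notebook’s approach but is not its verbatim code; the wrapper name run_generate is illustrative):

from functools import partial

import jax
from transformers import FlaxMT5ForConditionalGeneration, T5Tokenizer

model = FlaxMT5ForConditionalGeneration.from_pretrained("google/mt5-small")
tokenizer = T5Tokenizer.from_pretrained("google/mt5-small")

# max_length, top_k and do_sample steer Python-level control flow inside
# generate, so they must be marked static for JAX to trace the function
@partial(jax.jit, static_argnames=["max_length", "top_k", "do_sample"])
def run_generate(input_ids, max_length, top_k, do_sample):
    out = model.generate(
        input_ids=input_ids, max_length=max_length, top_k=top_k, do_sample=do_sample
    )
    return out.sequences

input_ids = tokenizer("The dog", return_tensors="np").input_ids
# the first call triggers XLA compilation; repeat calls with the same input
# shape run the cached kernel and are fast
sequences = run_generate(input_ids, max_length=20, top_k=30, do_sample=True)
print(tokenizer.batch_decode(sequences, skip_special_tokens=True))

When benchmarking, time the second call onwards, since the first call includes compilation. In practice you would also pad inputs to a fixed length, so that prompts of different lengths don’t each trigger a recompilation.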

Let me know if you have any other questions, happy to help!