transformers: When I use the following code on a TPU VM and call model.generate() for inference, it is very slow. It seems the TPU is not being used. What is the problem?
System Info
When I use the code below on a TPU VM and call model.generate() for inference, it is very slow; it seems the TPU is not being used. What is the problem? The JAX TPU devices are detected (see the device check in the Reproduction section below).
Who can help?
No response
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the examples folder (such as GLUE/SQuAD, …)
- My own task or dataset (give details below)
Reproduction
```python
import jax

# sanity-check that JAX can see the TPU
num_devices = jax.device_count()
device_type = jax.devices()[0].device_kind
assert "TPU" in device_type

from transformers import FlaxMT5ForConditionalGeneration, T5Tokenizer

model = FlaxMT5ForConditionalGeneration.from_pretrained("google/mt5-small")
tokenizer = T5Tokenizer.from_pretrained("google/mt5-small")

input_context = "The dog"
# encode input context
input_ids = tokenizer(input_context, return_tensors="np").input_ids
# generate candidates using sampling
outputs = model.generate(input_ids=input_ids, max_length=20, top_k=30, do_sample=True)
print(outputs)
```
Expected behavior
Expect it to be fast
About this issue
- Original URL
- State: closed
- Created 2 years ago
- Comments: 28 (10 by maintainers)
Hey @joytianya! Sorry about the late reply here! Cool to see that you’re using the Flax MT5 model!
The big speed-up from using JAX on TPU comes from JIT compiling a function: https://jax.readthedocs.io/en/latest/jax-101/02-jitting.html. It’s worth reading this guide to get a feel for how JAX + XLA + TPU work in combination to give you fast kernel execution.
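If you haven't used jit before, here's the flavour of it (this is essentially the selu example from that guide, nothing MT5-specific):

```python
import jax
import jax.numpy as jnp

def selu(x, alpha=1.67, lmbda=1.05):
    # A chain of element-wise ops: without jit each op launches its own kernel;
    # with jit, XLA fuses them into a single compiled kernel.
    return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

selu_jit = jax.jit(selu)

x = jnp.arange(1_000_000.0)
selu_jit(x).block_until_ready()  # first call traces + compiles (slow)
selu_jit(x).block_until_ready()  # later calls reuse the compiled kernel (fast)
```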
I’ve written an ipynb notebook that demonstrates how you can JIT compile the generate method: https://github.com/sanchit-gandhi/codesnippets/blob/main/benchmark_flaxmt5_jit_generate.ipynb
Running this with a ‘tiny’ version of the Flax MT5 model on CPU, I get a 75x speed-up from JIT compiling the generate function versus the vanilla generate function! That’s fast, right!
You can adapt the script for the mt5-small checkpoint as you require 🤗 You’ll need to pass any additional args that use boolean control flow in the generate method under `static_argnames` (as done with `max_length`, `top_k` and `do_sample`), as sketched below.
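Here’s a minimal sketch of that pattern (my own condensed version rather than the notebook verbatim; the `jit_generate` name is just illustrative, and note the first call is slow because it triggers compilation):

```python
import jax
from transformers import FlaxMT5ForConditionalGeneration, T5Tokenizer

model = FlaxMT5ForConditionalGeneration.from_pretrained("google/mt5-small")
tokenizer = T5Tokenizer.from_pretrained("google/mt5-small")

# Args that control Python-level control flow in generate must be static,
# otherwise JAX would try to trace them as array values.
jit_generate = jax.jit(
    model.generate, static_argnames=["max_length", "top_k", "do_sample"]
)

input_ids = tokenizer("The dog", return_tensors="np").input_ids

# First call: traces and XLA-compiles the whole generation loop (slow).
outputs = jit_generate(input_ids, max_length=20, top_k=30, do_sample=True)
# Subsequent calls with the same input shapes and static args reuse the
# compiled program (fast).
outputs = jit_generate(input_ids, max_length=20, top_k=30, do_sample=True)

print(tokenizer.batch_decode(outputs.sequences, skip_special_tokens=True))
```

Let me know if you have any other questions, happy to help!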