transformers: TF2 DeBERTaV2 runs super slow on TPUs

System Info

Latest version of transformers, Colab TPU, TensorFlow 2

Who can help?

@kamalkraj @Rocketknight1 @BigBird01

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, …)
  • My own task or dataset (give details below)

Reproduction

It’s currently hard to share the code and give access to the Google Cloud Storage bucket, but I believe any TF2 DeBERTaV2 code running on TPUs will have this issue.

Expected behavior

I’ve been trying to train a DeBERTa-v3 model on GPUs and TPUs. I got it working on multiple nodes and multiple GPUs using NVIDIA’s DeepLearningExamples repository (https://github.com/NVIDIA/DeepLearningExamples/blob/master/TensorFlow2/LanguageModeling/). I basically used the training setup and loop from the BERT code, the dataset utilities from the ELECTRA code, and the model from Hugging Face transformers, with some changes to share embeddings.

On 6x A40 (45 GB) GPUs I get around 1,370 sentences per second during training, which is lower than what NVIDIA reports for ELECTRA, but that’s fine.

Now the problem: on TPU I get only 20 sentences per second, roughly 68x lower than the GPU setup.

I traced the issue back to the tf.gather call here: https://github.com/huggingface/transformers/blob/main/src/transformers/models/deberta_v2/modeling_tf_deberta_v2.py#L525
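
The lookup in question can be sketched in NumPy as follows: relative-attention scores against P relative-position buckets are gathered along the last axis with a bucket index for every (query, key) pair. The sketch does this once with an index-based gather and once rewritten as a one-hot encoding followed by a dense contraction, a common way to trade memory for dense, accelerator-friendly ops. The helper names, the shapes, and the simple clip(i - j + k, 0, 2k - 1) index (the relative-distance definition from the DeBERTa paper) are assumptions for illustration; this is not the code in modeling_tf_deberta_v2.py.

```python
import numpy as np


def gather_last_axis(x, idx):
    """Index-based lookup along the last axis (hypothetical helper).

    x:   [B, S, P] scores of each query against P relative-position buckets
    idx: [B, S, S] bucket index for every (query i, key j) pair
    returns [B, S, S] with out[b, i, j] = x[b, i, idx[b, i, j]]
    """
    return np.take_along_axis(x, idx, axis=-1)


def gather_last_axis_one_hot(x, idx):
    """Same lookup written as one-hot encoding + dense contraction (hypothetical helper)."""
    depth = x.shape[-1]
    # one_hot[b, i, j, p] = 1.0 where idx[b, i, j] == p, else 0.0 -> shape [B, S, S, P]
    one_hot = (idx[..., None] == np.arange(depth)).astype(x.dtype)
    # out[b, i, j] = sum_p one_hot[b, i, j, p] * x[b, i, p]
    return np.einsum("bijp,bip->bij", one_hot, x)


# Toy example: relative index clip(i - j + k, 0, 2k - 1), broadcast over the batch
rng = np.random.default_rng(0)
B, S, k = 2, 6, 4
P = 2 * k
i = np.arange(S)[:, None]
j = np.arange(S)[None, :]
idx = np.broadcast_to(np.clip(i - j + k, 0, P - 1), (B, S, S))
c2p_scores = rng.standard_normal((B, S, P)).astype(np.float32)

gathered = gather_last_axis(c2p_scores, idx)
via_one_hot = gather_last_axis_one_hot(c2p_scores, idx)
print(np.allclose(gathered, via_one_hot))  # True
```

The one-hot tensor has shape [B, S, S, P], so this rewrite costs extra memory and FLOPs in exchange for avoiding the gather; whether it is actually faster depends on the hardware and should be checked with the profiler.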

I ran TPU profiling, and GatherV2 takes most of the time.

[Images: profiler output; GatherV2 taking most of the time; zoomed-in view of the fast ops]

Also, I’m not sure whether this is TPU-specific, since on GPUs training is also ~30% slower than regular ELECTRA.

About this issue

  • State: closed
  • Created 2 years ago
  • Comments: 34 (31 by maintainers)

Most upvoted comments

@sanchit-gandhi do you know a good point of contact for TPU problems?

Only for JAX on TPU, but I’ll ask around and see if anyone can help with TF!

For JAX BLOOM, we couldn’t even compile the 176B-parameter model with the naive implementation of concatenate_to_cache, let alone benchmark which operations consumed the bulk of the execution time! We swapped it for this more efficient implementation (with one-hot encodings, etc.): https://github.com/huggingface/bloom-jax-inference/blob/2a04aa519d262729d54adef3d19d63879f81ea89/bloom_inference/modeling_bloom/modeling_bloom.py#L119 Coincidentally, we’ve just run the JAX profiler on this implementation and are going through the trace with some of the Google JAX team later today. Will report back on how performance fares!
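
The linked implementation is not reproduced here. As a minimal NumPy sketch of the general idea, assuming it amounts to replacing a write at a data-dependent cache position with a one-hot mask blend, the two cache updates below produce the same result, but the one-hot version uses only dense elementwise ops. The function names and the [batch, max_len, heads, head_dim] cache layout are hypothetical.

```python
import numpy as np


def update_cache_indexed(cache, new, index):
    """Naive update: write `new` [B, H, D] into `cache` [B, T, H, D] at time step `index`."""
    out = cache.copy()
    out[:, index] = new
    return out


def update_cache_one_hot(cache, new, index):
    """One-hot update: blend with a position mask instead of indexing at `index`."""
    max_len = cache.shape[1]
    # mask is 1.0 only at the position being written, shaped to broadcast as [1, T, 1, 1]
    mask = (np.arange(max_len) == index).astype(cache.dtype)[None, :, None, None]
    # keep old entries everywhere except `index`, where `new` is written
    return cache * (1.0 - mask) + new[:, None] * mask


rng = np.random.default_rng(0)
cache = rng.standard_normal((2, 8, 4, 16)).astype(np.float32)  # [B, T, H, D]
new = rng.standard_normal((2, 4, 16)).astype(np.float32)       # [B, H, D]
print(np.allclose(update_cache_indexed(cache, new, 3),
                  update_cache_one_hot(cache, new, 3)))  # True
```

In JAX the same expression would be written with jnp arrays; the point of the design is that the compiled graph has the same shape and ops regardless of which position is being written.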

I tried disabling relative_attention in DeBERTa, which effectively turns the model into a regular BERT, and performance improved 40x 😅
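
Why disabling relative_attention removes the gathers can be seen from DeBERTa’s disentangled attention scores. The minimal NumPy sketch below computes one head’s unnormalized scores following the DeBERTa paper: a content-to-content term, plus content-to-position and position-to-content terms that each need a gather over relative-position buckets, scaled by sqrt(3d). In this sketch, with the flag off only the BERT-style scaled content term remains. It is simplified (no log-bucketed positions, masking, softmax, or multi-head batching), the function names are hypothetical, and it is not the transformers implementation.

```python
import numpy as np


def relative_index(seq_len, k):
    """Bucket index for each (i, j) pair: clip(i - j + k, 0, 2k - 1)."""
    i = np.arange(seq_len)[:, None]
    j = np.arange(seq_len)[None, :]
    return np.clip(i - j + k, 0, 2 * k - 1)


def attention_scores(qc, kc, qr, kr, relative_attention=True):
    """Unnormalized attention scores for a single head (hypothetical helper).

    qc, kc: [S, D] content queries / keys
    qr, kr: [2k, D] projected relative-position embeddings
    """
    seq_len, d = qc.shape
    scores = qc @ kc.T  # content-to-content, [S, S]
    if not relative_attention:
        return scores / np.sqrt(d)  # plain BERT-style scaled dot product, no gathers
    k = qr.shape[0] // 2
    idx = relative_index(seq_len, k)  # [S, S]
    # c2p[i, j] = qc[i] . kr[idx[i, j]]  (gather along the bucket axis)
    c2p = np.take_along_axis(qc @ kr.T, idx, axis=-1)
    # p2c[i, j] = kc[j] . qr[idx[j, i]]  (gather, then transpose)
    p2c = np.take_along_axis(kc @ qr.T, idx, axis=-1).T
    return (scores + c2p + p2c) / np.sqrt(3 * d)  # three terms -> scale by sqrt(3d)


rng = np.random.default_rng(0)
S, D, k = 5, 8, 3
qc, kc = rng.standard_normal((2, S, D))
qr, kr = rng.standard_normal((2, 2 * k, D))
plain = attention_scores(qc, kc, qr, kr, relative_attention=False)
full = attention_scores(qc, kc, qr, kr, relative_attention=True)
print(np.allclose(plain, qc @ kc.T / np.sqrt(D)))  # True
```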

@WissamAntoun I can confirm the issue reproduces here. Our TF DeBERTa implementation seems to have issues with XLA; I’m investigating now.

Also cc @sanchit-gandhi, because I’m not a TPU expert. Don’t worry about investigating this deeply, but if anything comes to mind when you read it, let me know!