iree: Missing fusion opportunities in BERT attention layer
What happened?
Here is the IR for a typical attention layer of the bert-base-uncased HuggingFace model with 1 sentence and sequence length 128: gist – see the bottom for the version with full weights. (Here is the program that generates it, if that is of interest.)
There are some missing fusion opportunities marked in the graph below (generated with --iree-flow-dump-dispatch-graph for the IR). Getting these fusions should yield about a 33% speedup on this particular workload (the unfused code accounts for 25% of the execution time). The main issue is that fusions are not happening into the outputs of the batch_matmul ops. There are missing biasadd fusions and an output transpose/reshape too.
Steps to reproduce your issue
$ iree-compile --iree-hal-target-backends=cuda --iree-hal-cuda-llvm-target-arch=sm_80 /tmp/attention_layer.mlir
What component(s) does this issue relate to?
Compiler
Version information
python -m pip list | grep -E "iree|torch"
iree-compiler 20230213.429
iree-runtime 20230213.429
pytorch-triton 2.0.0+0d7e753227
torch 2.0.0.dev20230210+cu117
torch-mlir 20230211.746
Additional context
No response
About this issue
- Original URL
- State: open
- Created a year ago
- Comments: 46 (34 by maintainers)
After folding the empty tensors (which hold the tiled output), the backend issue seems to be fixed. I compared the performance of matmul, fused transpose+matmul, and unfused transpose+matmul. The IRs are:
matmul:
fused transpose+matmul
unfused transpose+matmul
The performance numbers I got:
- matmul: 7855 items_per_second=12.5886k/s
- fused transpose+matmul: 7923 items_per_second=12.7482k/s
- unfused transpose+matmul: 1280 items_per_second=1.8357k/s

Now the perf of the fused version is as good as just doing the matmul.
Please advise the next step. @MaheshRavishankar
Thanks @silvasean for all the insights on this specific BERT case. Adding that we (at 🤗) are also very interested in contributing to make these optimizations available. Just to give you some insights from our side too: we have roughly 30–40k "BERTish" models available on our hub, so this might benefit quite a lot of people, with BERT still being our most popular model 🙂.
From the script you shared to reproduce the attention layer, I wanted to point out that you're not using the `attention_mask` input variable to the model/attention: https://github.com/silvasean/iree-benchmarking/blob/main/attention_layer/attention_layer.py#L102

In the case of `batch = 1` it doesn't have any impact, but for the other case (`batch > 1`) it is important to include this input, since the sequences in a batch will most likely not all be of the same length (token-wise), and then you need to keep track of padding to avoid including padded tokens inside the attention softmax(s). Here, in the script, the `attention_mask` will be filled from the BERT forward call with `torch.ones`, which might be seen as a constant and trimmed out (constant + matmul with ones) (ref in transformers' BERT model).

This also makes it possible to have "smarter" softmax implementations such as `masked_softmax(x, mask)`, which avoid extra expensive `exp` calls (I have seen multiple implementations of these in various HW compilers, for instance).

FWIW, at 🤗 we have quite a lot of customers of our inference products who send a single input (`batch = 1`), especially for generation tasks (GPT-like, Whisper, etc.) that sequentially generate multiple tokens; it is often a tricky challenge to batch these requests.

Concretely I think we have the following action items:
1. The named ops are really biting us in terms of the `embedding_size = num_heads * size_of_head` dimension split, forcing reshapes/transposes. We already solve this for elementwise ops by using higher-D representations and indexing map permutations to absorb reshapes/transposes, and I think we need a similar solution here. I think Nicolas and Thomas are iterating on some tech that could potentially pave the way to reliable codegen of arbitrary linalg.generic ops that boil down to stamping out an inner matmul primitive. This is probably not a "1 week project", but I think if we put our heads together we can solve it (and probably boost our matmul performance along the way too).
2. The broadcast fusion issue – we can't currently even benchmark num_sentences > 1 and get useful numbers due to the broadcast fusion issue (IREE's performance is so bad it isn't even meaningful to measure). I am rewriting my example in JAX, and maybe we will get lucky and it won't have such broadcasts. @powderluv – shouldn't we have hit this big performance issue when doing HF transformers through Torch-MLIR with >1 sentences batched?
3. I think we need to find a way to get the biasadd fusions. I ran my example with sequence length 512 instead of 128, and the biasadd and transpose fusions are still 15% of the runtime (25% in the original example). I will try to gather numbers for larger attention layers, where I suspect that the larger matmuls will dominate the time more, making the fusion less important.

It sounds like the higher-D matmul workstream I alluded to in point 1 would potentially solve all 3 issues directly, and we should invest in that.
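As an aside, the `masked_softmax(x, mask)` idea mentioned earlier can be sketched as follows. This is a minimal NumPy reference version, not any particular compiler's implementation; a real "smart" implementation could skip the `exp` for masked lanes entirely rather than computing and discarding them:

```python
import numpy as np

def masked_softmax(x, mask):
    """Softmax over the last axis that ignores padded positions.

    x:    [batch, seq_len] attention scores
    mask: [batch, seq_len] with 1 for real tokens, 0 for padding
    Padded positions get probability 0 and contribute nothing to the
    normalization, so padded tokens never leak into the attention.
    """
    # Send padded positions to the most negative float so exp() maps them to 0.
    neg_inf = np.finfo(x.dtype).min
    masked = np.where(mask.astype(bool), x, neg_inf)
    # Subtract the per-row max for numerical stability.
    shifted = masked - masked.max(axis=-1, keepdims=True)
    e = np.exp(shifted)
    return e / e.sum(axis=-1, keepdims=True)

scores = np.array([[1.0, 2.0, 3.0, 0.5]])
mask = np.array([[1, 1, 1, 0]])  # last token is padding
probs = masked_softmax(scores, mask)
```

The padded position receives exactly zero probability, and the remaining probabilities renormalize over the real tokens only.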
This is how it is written in the user's source code: HuggingFace example 1, HuggingFace example 2. It is also written this way in the HuggingFace JAX version, and similarly in nanoGPT.

This reshape/transpose is fundamental to how transformers work. I will provide some context on how the self-attention layer works, for intuition (a simple intro with diagrams is here).
To summarize the discussion below, there are 5 matmuls with various reshapes, forming a pattern something like `matvecish(softmax_rows(outer_productish(linear0(x), linear1(x))), linear2(x))`. The "matvec"/"outer product" are "ish" because instead of the special case of a dimension of size 1, they have that dimension of size 64 or 128, which for decently sized workloads is very small compared to the other dimensions. The `matvecish(softmax_rows(outer_productish(a, b)), c)` computation is called the "scaled dot product attention" computation (e.g. this is what FlashAttention computes while avoiding materializing the full outer product matrix in memory). The `linear[0-2]` are standard matmul-biasadd layers.

The sequence of operations is:

1. The three linear layers, each of shape `[B=num_sentences, M=sequence_length, K=embedding_size=num_heads * size_of_head] * [K=embedding_size, N=embedding_size] -> [B=num_sentences, M=sequence_length, N=embedding_size=num_heads * size_of_head]`.
2. Split `embedding_size` into `num_heads` and `size_of_head` and move `num_heads` into the batch dimensions, e.g. `[num_sentences, sequence_length, embedding_size=num_heads * size_of_head] -> [num_sentences * num_heads, sequence_length, size_of_head]`.
3. The "outer-product-ish" matmul: `[B=num_sentences*num_heads, M=sequence_length, K=size_of_head] * [B=num_sentences*num_heads, N=sequence_length, K=size_of_head] -> [B=num_sentences*num_heads, M=sequence_length, N=sequence_length]`, followed by a softmax along the M dimension of this. (Note that this output is of size $O(\texttt{sequence\_length}^2)$.)
4. The "matvec-ish" matmul: `[B=num_sentences*num_heads, M=sequence_length, K=sequence_length] * [B=num_sentences*num_heads, N=size_of_head, K=sequence_length] -> [B=num_sentences*num_heads, M=sequence_length, N=size_of_head]`.
5. Reshape/transpose back to `[num_sentences, sequence_length, embedding_size=num_heads * size_of_head]`.
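For concreteness, the shape flow above can be traced in a small NumPy sketch (biases, masking, and the attention-output projection are omitted, and the weights are random placeholders; this mirrors the computation, not any specific IR):

```python
import numpy as np

# Representative bert-base-uncased sizes from the discussion.
num_sentences, sequence_length = 1, 128
num_heads, size_of_head = 12, 64
embedding_size = num_heads * size_of_head  # 768

rng = np.random.default_rng(0)
x = rng.standard_normal((num_sentences, sequence_length, embedding_size))
# Placeholder weights for the three linear layers (linear0..linear2).
wq, wk, wv = (rng.standard_normal((embedding_size, embedding_size)) for _ in range(3))

def split_heads(t):
    """[B, S, H*D] -> [B*H, S, D]: the reshape/transpose the issue is about."""
    b, s, _ = t.shape
    return (t.reshape(b, s, num_heads, size_of_head)
             .transpose(0, 2, 1, 3)
             .reshape(b * num_heads, s, size_of_head))

q, k, v = (split_heads(x @ w) for w in (wq, wk, wv))       # matmuls 1-3
scores = q @ k.transpose(0, 2, 1) / np.sqrt(size_of_head)  # "outer-product-ish" matmul
e = np.exp(scores - scores.max(axis=-1, keepdims=True))
attn = e / e.sum(axis=-1, keepdims=True)                   # softmax; O(sequence_length^2) output
out = attn @ v                                             # "matvec-ish" matmul
# Reshape/transpose back: [B*H, S, D] -> [B, S, H*D]
out = (out.reshape(num_sentences, num_heads, sequence_length, size_of_head)
          .transpose(0, 2, 1, 3)
          .reshape(num_sentences, sequence_length, embedding_size))
```

The two explicit `transpose`+`reshape` pairs here are exactly the relayouts that the named-op lowering cannot currently absorb.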
Summarizing the below discussion on sizes, we are talking about:

- `num_sentences = 16 to 64` (this is a pure batch dimension) (correction: this used to have an upper bound of 1000, but that is the global, rather than minibatch, size when training)
- `sequence_length = 128 to 2048`
- `embedding_size = num_heads * size_of_head = 768 to 12288`, with `size_of_head = 64 to 128` (aligned to 16, at least in GPT-3)

The `embedding_size` is split into what are called "heads" (this gives it the name "multi-headed attention"): `embedding_size = num_heads * size_of_head`. `size_of_head` is relatively constant, with size 64 to 128 over 3 orders of magnitude of model size. The transposes/reshapes come from the fact that the first 3 matmuls contract over `embedding_size`, while the other two treat `num_heads` as a batch dimension, which requires a reshape/transpose (since we are calling into named ops with fixed dimension orderings). If we used linalg.generic ops with reduction dimensions to model higher-D matmul-like ops (incorporating transposes and multiple reduction dims), then we could absorb all the transposes/reshapes quite cleanly, since we could properly model `embedding_size` as two dimensions, `num_heads` and `size_of_head`.

For context on the actual magnitude of the numbers: Transformer models are known to be relatively insensitive to the particular choices of many of the hyperparameters for a given compute/parameter budget, and there is a known "sweet spot" for language models. So there is actually a pretty small space of meaningful parameters for transformer-based language models, since many of the hyperparameters are related. As such, I will use the sizes from the GPT-3 paper as representative (especially Table 2.1).
The input to the self-attention layer consists of vectors of a particular embedding size (embedding_size, or $d_{model}$ in Table 2.1). For the 125M parameter "GPT-3 Small" this size is 768, the same as in the example IR I gave, which was for bert-base-uncased (110M parameters; see how the sizes and choices are very similar between GPT-3 and BERT). For the 175B parameter "GPT-3" this size is 12288 (note: about 1000x more params, but this dimension only increases by 16x). There are [number of sentences, sentence length in tokens] of them. This can vary between training and inference, but the sequence length for all GPT-3-family models is 2048, while for BERT it is often 512; the batch size in a datacenter inference scenario might be 16 or 32, more or less depending on latency budget, while for training a GPT-3-like model it will be on the order of 16-64. So we are talking about inputs of size [num_sentences, sequence_length, embedding_size] = [16-1000ish, 128-2048, 768-12288]. They are usually chosen as nice aligned sizes.
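The higher-D formulation alluded to above can be illustrated with einsum notation, where `embedding_size` is kept as the two dimensions `num_heads` and `size_of_head` throughout. This is only a sketch of the indexing-map idea in NumPy, not IREE's actual representation (the sizes here are made small so it runs quickly):

```python
import numpy as np

num_sentences, sequence_length = 2, 16
num_heads, size_of_head = 4, 8

rng = np.random.default_rng(0)
# Keep embedding_size as two axes (num_heads, size_of_head) instead of
# flattening to a single axis of size num_heads * size_of_head.
q = rng.standard_normal((num_sentences, sequence_length, num_heads, size_of_head))
k = rng.standard_normal((num_sentences, sequence_length, num_heads, size_of_head))

# "Outer-product-ish" matmul with num_heads as a pure batch dimension:
# expressed entirely via the index labels, with no reshape/transpose ops.
scores = np.einsum('bmhd,bnhd->bhmn', q, k)

# A linear layer contracting over embedding_size means reducing over BOTH
# h and d at once (two reduction dimensions), again with no reshapes.
w = rng.standard_normal((num_heads, size_of_head, num_heads * size_of_head))
y = np.einsum('bshd,hde->bse', q, w)
```

A linalg.generic with the analogous indexing maps and iterator types could absorb the same transposes/reshapes that the fixed-layout named ops force today.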
FWIW, in Torch-MLIR we had this kind of gridlock requiring all backends to support something before one backend could make progress, and it was quite problematic. We solved that in Torch-MLIR by allowing backends some limited control over the target-independent lowering process. Over time we are going to grow more and more backends with different sets of maintainers and it is going to be increasingly problematic to atomically update all of them in order to make progress.
As is always the case with such programs, it is useful to take a full view of the program, and not zoom into small portions (before trying to devise a strategy). The issue here starts with
Two things to look at. First, the `batch_matmul` is actually just a matmul. Second, there are a lot of "broadcast-like" operations here that are actually just copies if you drop the unit dimensions. Indexing maps like `#map2 = affine_map<(d0, d1, d2) -> (0, d1, d2)>` are problematic. They lead to lots of corner cases and make analysis of indexing maps unnecessarily complicated. There is an attempt in IREE (built in MLIR) to get rid of such spurious unit dimensions. If I do a "cleanup" of the graph to replace `batch_matmul` with `matmul` and replace the use of `#map2` with a more canonical representation, i.e. start with this input, then I end up with 7 dispatches instead of the 13 in the bug above. Here is the IR after some relevant passes: https://gist.github.com/MaheshRavishankar/11cf9d4ff1beca0d004551f7f7d807b2
Looking at the remaining dispatches, there is a transpose at the very beginning. We can discuss whether we want to fuse it with the next dispatch, which is doing a matmul. That is a possibility, but as I said above, a better metric to go for is to give the matmul the better layout and then try to fuse the relayout operation with other operations (instead of indexing on the transpose and trying to fuse that with a matmul; the goal should be to get the matmul in the right layout).
The other remaining inefficiency is two dispatches, each with a single generic op. These are doing two different transposes. It might be worth looking into where those transposes are coming from and then fixing them at the right place. My point is, when it comes to fusion it is useful to look at the entire IR from the input until dispatch region formation and try to "clean up" the IR into something that requires far fewer actual rules within the fusion itself…