iree: Missing fusion opportunities in BERT attention layer

What happened?

Here is the IR for a typical attention layer of the bert-base-uncased HuggingFace model with 1 sentence and sequence length 128: gist (see bottom for the version with full weights). Here is the program that generates it, if that is of interest.

There are some missing fusion opportunities marked in the graph below (generated with --iree-flow-dump-dispatch-graph for the IR). For this particular workload, getting these fusions should yield about a 33% speedup (the unfused code occupies 25% of the execution time). The main issue is that fusions are not happening into the outputs of the batch_matmul ops. There are missing biasadd fusions and a missing output transpose/reshape fusion too.

[image: dispatch graph with the missing fusions marked]
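As a quick sanity check on how the 25% and 33% figures relate, here is a minimal sketch. It assumes the fused ops would cost essentially nothing once folded into the matmul dispatches, so it is an upper bound. The helper name `speedup_if_removed` is made up for illustration.

```python
def speedup_if_removed(fraction_removed: float) -> float:
    """Overall speedup if a given fraction of the runtime is eliminated."""
    return 1.0 / (1.0 - fraction_removed)

# 25% of the execution time is spent in the unfused code.
gain = speedup_if_removed(0.25)
print(f"{gain:.2f}x, i.e. about {gain - 1:.0%} faster")
```

In practice the fused epilogue is not entirely free, so the real gain would land somewhat below this bound.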

Steps to reproduce your issue

$ iree-compile --iree-hal-target-backends=cuda --iree-hal-cuda-llvm-target-arch=sm_80 /tmp/attention_layer.mlir

What component(s) does this issue relate to?

Compiler

Version information

python -m pip list| grep -E "iree|torch"
iree-compiler      20230213.429
iree-runtime       20230213.429
pytorch-triton     2.0.0+0d7e753227
torch              2.0.0.dev20230210+cu117
torch-mlir         20230211.746

Additional context

No response

About this issue

  • State: open
  • Created a year ago
  • Comments: 46 (34 by maintainers)

Most upvoted comments

After folding the empty tensors (which hold the tiled output), the backend issue seems to be fixed. I compared the performance of matmul, fused transpose+matmul, and unfused transpose+matmul. The IR for each is:

matmul:

module {
  func.func @forward(%arg0: tensor<1x128x768xf32>, %arg1 : tensor<768x768xf32>) -> tensor<128x768xf32> {
    %cst_1 = arith.constant 0.000000e+00 : f32
    %cst_2 = arith.constant 8.000000e+00 : f32
    %2 = tensor.empty() : tensor<128x768xf32>
    %collapsed = tensor.collapse_shape %arg0 [[0, 1], [2]] : tensor<1x128x768xf32> into tensor<128x768xf32>
    %3 = linalg.fill ins(%cst_1 : f32) outs(%2 : tensor<128x768xf32>) -> tensor<128x768xf32>
    %16 = linalg.matmul ins(%collapsed, %arg1 : tensor<128x768xf32>, tensor<768x768xf32>) outs(%3 : tensor<128x768xf32>) -> tensor<128x768xf32>
    return %16 : tensor<128x768xf32>
  }
}

fused transpose+matmul

#map = affine_map<(d0, d1) -> (d0, d1)>
#map1 = affine_map<(d0, d1) -> (d1, d0)>
module {
  func.func @forward(%arg0: tensor<1x128x768xf32>, %arg1 : tensor<768x768xf32>) -> tensor<128x768xf32> {
    %cst_1 = arith.constant 0.000000e+00 : f32
    %cst_2 = arith.constant 8.000000e+00 : f32
    %1 = tensor.empty() : tensor<768x768xf32>
    %2 = tensor.empty() : tensor<128x768xf32>
    %collapsed = tensor.collapse_shape %arg0 [[0, 1], [2]] : tensor<1x128x768xf32> into tensor<128x768xf32>
    %3 = linalg.fill ins(%cst_1 : f32) outs(%2 : tensor<128x768xf32>) -> tensor<128x768xf32>
    %4 = flow.dispatch.region -> (tensor<128x768xf32>) {
      %15 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel"]} ins(%arg1 : tensor<768x768xf32>) outs(%1 : tensor<768x768xf32>) {
      ^bb0(%in: f32, %out: f32):
        linalg.yield %in : f32
      } -> tensor<768x768xf32>
      %16 = linalg.matmul ins(%collapsed, %15 : tensor<128x768xf32>, tensor<768x768xf32>) outs(%3 : tensor<128x768xf32>) -> tensor<128x768xf32>
      flow.return %16 : tensor<128x768xf32>
    }
    return %4 : tensor<128x768xf32>
  }
}

unfused transpose+matmul

#map = affine_map<(d0, d1) -> (d0, d1)>
#map1 = affine_map<(d0, d1) -> (d1, d0)>
module {
  func.func @forward(%arg0: tensor<1x128x768xf32>, %arg1 : tensor<768x768xf32>) -> tensor<128x768xf32> {
    %cst_1 = arith.constant 0.000000e+00 : f32
    %cst_2 = arith.constant 8.000000e+00 : f32
    %1 = tensor.empty() : tensor<768x768xf32>
    %2 = tensor.empty() : tensor<128x768xf32>
    %collapsed = tensor.collapse_shape %arg0 [[0, 1], [2]] : tensor<1x128x768xf32> into tensor<128x768xf32>
    %3 = linalg.fill ins(%cst_1 : f32) outs(%2 : tensor<128x768xf32>) -> tensor<128x768xf32>
    %15 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel"]} ins(%arg1 : tensor<768x768xf32>) outs(%1 : tensor<768x768xf32>) {
      ^bb0(%in: f32, %out: f32):
        linalg.yield %in : f32
    } -> tensor<768x768xf32>
    %16 = linalg.matmul ins(%collapsed, %15 : tensor<128x768xf32>, tensor<768x768xf32>) outs(%3 : tensor<128x768xf32>) -> tensor<128x768xf32>
    return %16 : tensor<128x768xf32>
  }
}

The performance numbers I got:

  • matmul: 7855 items_per_second=12.5886k/s
  • fused transpose+matmul: 7923 items_per_second=12.7482k/s
  • unfused transpose+matmul: 1280 items_per_second=1.8357k/s

Now the performance of the fused version is as good as the plain matmul.

Please advise the next step. @MaheshRavishankar

Thanks @silvasean for all the insights on this specific BERT case. I'd also add that we (at 🤗) are very interested in contributing to make these optimizations available. To give you some insight from our side too: we have roughly 30-40k “BERTish” models available on our hub, so it might benefit quite a lot of people, with BERT still being our most popular model 🙂.

From the script you shared to reproduce the attention layer, I wanted to point out that you’re not using the attention_mask input variable to the model/attention: https://github.com/silvasean/iree-benchmarking/blob/main/attention_layer/attention_layer.py#L102

In the case of batch = 1 it doesn’t have any impact. For the other case (batch > 1) it is important to include this input, as the sequences in a batch will most likely not all be of the same length (token-wise), and you then need to keep track of padding to avoid including padded tokens inside the attention softmax(es).

Here, in the script, the attention_mask will be filled from the BERT forward call with torch.ones which might be seen as a constant and trimmed out (constant + matmul with ones) (ref in transformers’ BERT model).

This also makes it possible to have “smarter” softmax implementations such as masked_softmax(x, mask), which avoid extra expensive exp calls (I have seen multiple implementations of this in various HW compilers, for instance).
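To make the masked softmax idea concrete, here is a minimal plain-Python sketch over a single row of attention scores. It is an illustration only, not the implementation used by transformers or by any particular compiler. The 0/1 mask convention (1 = real token) is an assumption.

```python
import math

def masked_softmax(scores, mask):
    """Softmax over one row of scores, ignoring padded positions.

    scores: list of floats. mask: list of 0/1 flags (1 = real token, assumed).
    Masked positions get probability 0 and never reach math.exp.
    """
    kept = [s for s, m in zip(scores, mask) if m]
    if not kept:
        return [0.0] * len(scores)
    row_max = max(kept)  # subtract the max for numerical stability
    exps = [math.exp(s - row_max) if m else 0.0 for s, m in zip(scores, mask)]
    total = sum(exps)
    return [e / total for e in exps]

row = [2.0, 1.0, 0.5, 7.0]
mask = [1, 1, 1, 0]  # the last position is padding
probs = masked_softmax(row, mask)
print(probs)
```

With an all-ones mask this reduces to an ordinary softmax, which matches the batch = 1 case where the mask has no effect.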

FWIW, at 🤗 we have quite a lot of customers of our inference products who send a single input (batch = 1), especially for generation tasks (GPT-like, Whisper, etc.) that sequentially generate multiple tokens. It is often a tricky challenge to batch these requests.

Concretely I think we have the following action items:

  1. The named ops are really biting us in terms of the embedding_size = num_heads * size_of_head dimension split and forcing reshape/transposes. We already solve this for elementwise ops by using higher-D representations and indexing map permutations to absorb reshape/transposes, and I think we need a similar solution here. I think Nicolas and Thomas are iterating on some tech that could potentially pave the way to reliable codegen of arbitrary linalg.generic that boils down to stamping out an inner matmul primitive. This is probably not a “1 week project” but I think if we put our heads together we can solve it (and probably boost our matmul performance along the way too).

    • Note: this also probably solves the other 2 issues and should improve our general einsum handling too.
    • Idea: Let’s form up a small “tiger team” for “generalized matmul codegen” with minimally Thomas, Mahesh, and Nicolas for guiding us on how to best apply the structured codegen tech to this problem. We will want to agree on 1) what is the canonical target-independent higher-D flow-level representation we want for this code 2) given that representation, how do we codegen it efficiently across targets.
  2. The broadcast fusion issue – we currently can’t even benchmark the num_sentences > 1 case and get useful numbers due to the broadcast fusion issue (IREE’s performance is so bad it isn’t even meaningful to measure). I am rewriting my example in JAX and maybe we will get lucky and it won’t have such broadcasts. @powderluv – shouldn’t we have hit this big performance issue when doing HF transformers through Torch-MLIR with >1 sentences batched?

  3. I think we need to find a way to get the biasadd fusions. I ran my example with sequence length 512 instead of 128 and the biasadd and transpose fusions are still 15% of the runtime (25% in the original example). I will try to gather numbers for larger attention layers, where I suspect that the larger matmuls will dominate the time more, making the fusion less important.

    • That said, we are trying to build a state of the art ML compiler and biasadd fusion is table stakes. To be clear, the most straightforward PyTorch eager execution gets this biasadd fusion – currently other than removing dispatching overhead we are still sub-PyTorch-eager on this workload (more dispatches AND slower dispatches). With sequence length 512 we are spending 254us compute time and PyTorch eager is 183us (matmul compute time is 193us for IREE vs 144us for PyTorch, so a significant fraction is still non-matmul cost). Once we get into training and large inference workloads IREE’s dispatch efficiency won’t hide the underlying performance characteristics here.
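The sequence length 512 timings quoted in item 3 can be broken down as follows. This is plain arithmetic on the numbers above, with no new measurements. The variable names are only for illustration.

```python
# Timings in microseconds as quoted above (sequence length 512).
iree_total, iree_matmul = 254, 193
torch_total, torch_matmul = 183, 144

iree_other = iree_total - iree_matmul     # time outside matmuls in IREE
torch_other = torch_total - torch_matmul  # time outside matmuls in PyTorch eager

print(f"IREE non-matmul:    {iree_other} us ({iree_other / iree_total:.0%} of total)")
print(f"PyTorch non-matmul: {torch_other} us ({torch_other / torch_total:.0%} of total)")
print(f"Total gap:  {iree_total / torch_total:.2f}x")
print(f"Matmul gap: {iree_matmul / torch_matmul:.2f}x")
```

Both the matmul and the non-matmul portions are slower, so closing the gap needs both better matmul codegen and the missing fusions.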

It sounds like the higher-D matmul workstream I alluded to in point 1. would potentially solve all 3 issues directly and we should invest in that.

Looking at the links sent by Sean, the transpose comes from the input. Maybe Sean has a better understanding, as I’m not sure why the batch size gets collapsed/expanded.

This is how it is written in the user’s source code HuggingFace example 1, HuggingFace example 2. It is also written as such in the HuggingFace JAX version. Also it is similarly written in nanoGPT.

This reshape/transpose is fundamental to how transformers work. I will provide some context on how the self-attention layer works for intuition (a simple intro with diagrams is here).

To summarize the discussion below, there are 5 matmuls with various reshapes, forming a pattern something like matvecish(softmax_rows(outer_productish(linear0(x), linear1(x))), linear2(x)). The “matvec”/“outer product” are “ish” because instead of the special case of a dimension of size 1, they have that dimension of size 64 or 128, which for decently sized workloads is very small compared to the other dimensions. The matvecish(softmax_rows(outer_productish(a, b)), c) computation is called the “scaled dot product attention” computation (e.g. what is done by FlashAttention to avoid materializing the full outer product matrix in memory). The linear[0-2] are standard matmul-biasadd layers.

  • 3x batch matmul [B=num_sentences, M=sequence_length, K=embedding_size=num_heads * size_of_head] * [K=embedding_size, N=embedding_size] -> [B=num_sentences, M=sequence_length, N=embedding_size=num_heads * size_of_head].
  • 3x reshape/transpose to split embedding_size into num_heads and size_of_head and move num_heads into the batch dimensions. E.g. [num_sentences, sequence_length, embedding_size=num_heads * size_of_head] -> [num_sentences * num_heads, sequence_length, size_of_head]
  • “outer product-ish” batch matmul [B=num_sentences*num_heads, M=sequence_length, K=size_of_head] * [B=num_sentences*num_heads, N=sequence_length, K=size_of_head] -> [B=num_sentences*num_heads, M=sequence_length, N=sequence_length]. Softmax is applied to each row of this, i.e. normalizing along the N dimension. (Note that this output is size $O(\texttt{sequence\_length}^2)$.)
  • “matvec-ish” batch matmul [B=num_sentences*num_heads, M=sequence_length, K=sequence_length] * [B=num_sentences*num_heads, N=size_of_head, K=sequence_length] -> [B=num_sentences*num_heads, M=sequence_length, N=size_of_head]
  • Finally this is transpose/reshaped into [num_sentences, sequence_length, embedding_size=num_heads * size_of_head].
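The shape flow in the list above can be traced with a small sketch. The function name `attention_shapes` and the dictionary keys are made up for illustration. The example sizes (1 sentence, sequence length 128, 12 heads of size 64) are the bert-base-uncased configuration from the original report, where 12 * 64 = 768.

```python
def attention_shapes(num_sentences, sequence_length, num_heads, size_of_head):
    """Result shape of each step of the self-attention layer described above."""
    embedding_size = num_heads * size_of_head
    batch = num_sentences * num_heads
    return {
        "linear (x3)": (num_sentences, sequence_length, embedding_size),
        "split heads (x3)": (batch, sequence_length, size_of_head),
        "outer-product-ish scores": (batch, sequence_length, sequence_length),
        "matvec-ish output": (batch, sequence_length, size_of_head),
        "merged heads": (num_sentences, sequence_length, embedding_size),
    }

shapes = attention_shapes(num_sentences=1, sequence_length=128,
                          num_heads=12, size_of_head=64)
for name, shape in shapes.items():
    print(f"{name:26s} {shape}")
```

Only the scores tensor has a shape that grows quadratically with sequence_length. Every other intermediate is linear in it.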

Summarizing the below discussion on sizes, we are talking about

  • num_sentences = 16 to 64 (this is a pure batch dimension) (correction: this used to have an upper bound of 1000, but that is the global, rather than minibatch size when training)
  • sequence_length = 128 to 2048
  • embedding_size = num_heads * size_of_head = 768 to 12288, with size_of_head = 64 to 128 (aligned to 16 at least in GPT-3)
  • They are usually chosen as nice aligned sizes.

The embedding_size is split into what are called “heads” (this gives it the name “multi-headed attention”): embedding_size = num_heads * size_of_head. size_of_head is relatively constant, with size 64 to 128 over 3 orders of magnitude of model size. The transposes/reshapes come from the fact that the first 3 matmuls contract over embedding_size, while the other two treat num_heads as a batch dimension, which requires a reshape/transpose (since we are calling into named ops with fixed dimension orderings). If we used linalg.generic ops with reduction dimensions to model higher-D matmul-like ops (incorporating transposes and multiple reduction dims), then we could absorb all the transposes/reshapes quite cleanly, since we could properly model embedding_size as two dimensions, num_heads and size_of_head.
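The following toy sketch in plain Python (not linalg IR) illustrates that claim. It compares a fixed-layout 2-D matmul followed by an explicit reshape/transpose against a single loop nest that writes the [num_heads, sequence_length, size_of_head] layout directly. Both function names and the tiny sizes are made up for illustration.

```python
import random

def linear_then_split(x, w, num_heads, size_of_head):
    """Named-op style: 2-D matmul, then reshape/transpose to [heads, seq, head]."""
    seq, k_dim = len(x), len(w)
    n_dim = num_heads * size_of_head
    y = [[sum(x[m][k] * w[k][n] for k in range(k_dim)) for n in range(n_dim)]
         for m in range(seq)]
    # Reshape [seq, heads*head] -> [seq, heads, head], then transpose to [heads, seq, head].
    return [[[y[m][h * size_of_head + d] for d in range(size_of_head)]
             for m in range(seq)] for h in range(num_heads)]

def higher_d_matmul(x, w, num_heads, size_of_head):
    """Generic style: iterate over (h, m, d, k), treating w as if it were [k, h, d]."""
    seq, k_dim = len(x), len(w)
    return [[[sum(x[m][k] * w[k][h * size_of_head + d] for k in range(k_dim))
              for d in range(size_of_head)]
             for m in range(seq)] for h in range(num_heads)]

random.seed(0)
seq, k_dim, num_heads, size_of_head = 3, 4, 2, 2
x = [[random.random() for _ in range(k_dim)] for _ in range(seq)]
w = [[random.random() for _ in range(num_heads * size_of_head)] for _ in range(k_dim)]

a = linear_then_split(x, w, num_heads, size_of_head)
b = higher_d_matmul(x, w, num_heads, size_of_head)
print(a == b)
```

The second version never materializes the [seq, heads*head] intermediate. The layout change lives entirely in how the output is indexed, which is the role indexing maps would play in a linalg.generic.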

For context on the actual magnitude of the numbers, Transformer models are known to be relatively insensitive to particular choices of many of the hyperparameters for a given compute/parameter budget, and there is a known “sweet spot” for language models. So there is actually a pretty small space of meaningful parameters for transformer-based language models, since many of the hyperparameters are related. As such I will use the sizes from the GPT-3 paper as representative (especially Table 2.1).

The input to the self-attention layer consists of vectors of a particular embedding size (embedding_size, or $d_{model}$ in Table 2.1). For the 125M parameter “GPT-3 Small” this size is 768, the same as in the example IR I gave, which was for bert-base-uncased (110M parameters; note how similar the sizes and choices are between GPT-3 and BERT). For the 175B parameter “GPT-3” this size is 12288 (note: about 1000x more params, but this dimension only increases by 16x). There are [number of sentences, sentence length in tokens] of them. This can vary between training and inference, but the sequence length for all GPT-3-family models is 2048, while for BERT it is often 512. The batch size in a datacenter inference scenario might be 16 or 32, more or less depending on latency budget, while for training a GPT-3-like model it will be on the order of 16-64. So we are talking about inputs of size [num_sentences, sequence_length, embedding_size] = [16-64, 128-2048, 768-12288]. They are usually chosen as nice aligned sizes.
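To get a feel for why the $O(\texttt{sequence\_length}^2)$ score tensor matters at these sizes, here is a rough estimate of its f32 footprint if fully materialized. The helper name `score_tensor_bytes` is made up. The second configuration (16 sentences, 12 heads, sequence length 2048) is just one illustrative point inside the ranges above, not a measured workload.

```python
def score_tensor_bytes(num_sentences, num_heads, sequence_length, bytes_per_element=4):
    """Size of the [num_sentences*num_heads, seq, seq] attention score tensor in bytes."""
    return num_sentences * num_heads * sequence_length * sequence_length * bytes_per_element

# 1 sentence, 12 heads, sequence length 128 (the IR in this issue).
small = score_tensor_bytes(1, 12, 128)
# Illustrative larger point: 16 sentences, 12 heads, sequence length 2048.
large = score_tensor_bytes(16, 12, 2048)

print(f"small: {small / 2**20:.2f} MiB")
print(f"large: {large / 2**20:.0f} MiB")
```

This growth is what approaches like FlashAttention avoid by not materializing the full score matrix in memory.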

Fusing broadcasts with the matmul/batch_matmul layer is definitely worth it. I haven’t looked at this in a bit, but I think I explicitly disallow this fusion because it is a matter of the backends being able to handle it. Based on my understanding of the CPU backends, it shouldn’t be too hard to support this. For CUDA backends, since operands are promoted to shared memory, this is one place where the broadcast of the LHS should be managed.

FWIW, in Torch-MLIR we had this kind of gridlock requiring all backends to support something before one backend could make progress, and it was quite problematic. We solved that in Torch-MLIR by allowing backends some limited control over the target-independent lowering process. Over time we are going to grow more and more backends with different sets of maintainers and it is going to be increasingly problematic to atomically update all of them in order to make progress.

As is always the case with such programs, it is useful to take a full view of the program and not zoom into small portions before trying to devise a strategy. The issue here starts with

#map2 = affine_map<(d0, d1, d2) -> (0, d1, d2)>
#map3 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map4 = affine_map<(d0, d1, d2) -> (d1, d2)>
    %2 = tensor.empty() : tensor<1x12x768xf32>
    %3 = linalg.generic {indexing_maps = [#map2, #map3], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0 : tensor<1x12x768xf32>) outs(%2 : tensor<1x12x768xf32>) {
    ^bb0(%in: f32, %out: f32):
      linalg.yield %in : f32
    } -> tensor<1x12x768xf32>
    %4 = tensor.empty() : tensor<1x768x768xf32>
    %5 = linalg.generic {indexing_maps = [#map4, #map3], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1 : tensor<768x768xf32>) outs(%4 : tensor<1x768x768xf32>) {
    ^bb0(%in: f32, %out: f32):
      linalg.yield %in : f32
    } -> tensor<1x768x768xf32>
    %6 = linalg.fill ins(%cst_6 : f32) outs(%2 : tensor<1x12x768xf32>) -> tensor<1x12x768xf32>
    %7 = linalg.batch_matmul ins(%3, %5 : tensor<1x12x768xf32>, tensor<1x768x768xf32>) outs(%6 : tensor<1x12x768xf32>) -> tensor<1x12x768xf32>
    %8 = linalg.generic {indexing_maps = [#map2, #map5, #map3], iterator_types = ["parallel", "parallel", "parallel"]} ins(%7, %cst_5 : tensor<1x12x768xf32>, tensor<768xf32>) outs(%2 : tensor<1x12x768xf32>) {
    ^bb0(%in: f32, %in_16: f32, %out: f32):
      %46 = arith.addf %in, %in_16 : f32
      linalg.yield %46 : f32
    } -> tensor<1x12x768xf32>

Two things to look at. First, the batch_matmul is actually just a matmul. Second, there are a lot of “broadcast-like” operations here that are actually just copies if you drop the unit dimensions. Indexing maps like #map2 = affine_map<(d0, d1, d2) -> (0, d1, d2)> are problematic: they lead to lots of corner cases and make analysis of indexing maps unnecessarily complicated. There is an attempt in IREE (built in MLIR) to get rid of such spurious unit dimensions. If I do a “cleanup” of the graph to replace batch_matmul with matmul and replace the uses of #map2 with a more canonical representation, i.e. start with this input

func.func @forward(%arg0: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub} {
  %cst = arith.constant -3.40282347E+38 : f32
  %cst_0 = arith.constant 0.000000e+00 : f32
  %cst_1 = arith.constant dense_resource<__elided__> : tensor<768xf32>
  %cst_2 = arith.constant dense_resource<__elided__> : tensor<768x768xf32>
  %cst_3 = arith.constant dense<8.000000e+00> : tensor<f32>
  %0 = hal.tensor.import %arg0 : !hal.buffer_view -> tensor<1x12x768xf32>
  %1 = tensor.empty() : tensor<768x768xf32>
  %2 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>], iterator_types = ["parallel", "parallel"]} ins(%cst_2 : tensor<768x768xf32>) outs(%1 : tensor<768x768xf32>) {
  ^bb0(%in: f32, %out: f32):
    linalg.yield %in : f32
  } -> tensor<768x768xf32>
  %3 = tensor.empty() : tensor<12x768xf32>
  %collapsed = tensor.collapse_shape %0 [[0, 1], [2]] : tensor<1x12x768xf32> into tensor<12x768xf32>
  %8 = linalg.fill ins(%cst_0 : f32) outs(%3 : tensor<12x768xf32>) -> tensor<12x768xf32>
  %9 = linalg.matmul ins(%collapsed, %2 : tensor<12x768xf32>, tensor<768x768xf32>) outs(%8 : tensor<12x768xf32>) -> tensor<12x768xf32>
  %10 = tensor.empty() : tensor<12x768xf32>
  %11 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%9, %cst_1 : tensor<12x768xf32>, tensor<768xf32>) outs(%10 : tensor<12x768xf32>) {
  ^bb0(%in: f32, %in_9: f32, %out: f32):
    %45 = arith.addf %in, %in_9 : f32
    linalg.yield %45 : f32
  } -> tensor<12x768xf32>
  %expanded_6 = tensor.expand_shape %11 [[0], [1, 2]] : tensor<12x768xf32> into tensor<12x12x64xf32>
  %12 = tensor.empty() : tensor<12x12x64xf32>
  %13 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d1, d0, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_6 : tensor<12x12x64xf32>) outs(%12 : tensor<12x12x64xf32>) {
  ^bb0(%in: f32, %out: f32):
    linalg.yield %in : f32
  } -> tensor<12x12x64xf32>
  %14 = tensor.empty() : tensor<12x64x12xf32>
  %15 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d2, d1)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%13 : tensor<12x12x64xf32>) outs(%14 : tensor<12x64x12xf32>) {
  ^bb0(%in: f32, %out: f32):
    linalg.yield %in : f32
  } -> tensor<12x64x12xf32>
  %16 = tensor.empty() : tensor<12x12x64xf32>
  %17 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%13 : tensor<12x12x64xf32>) outs(%16 : tensor<12x12x64xf32>) {
  ^bb0(%in: f32, %out: f32):
    linalg.yield %in : f32
  } -> tensor<12x12x64xf32>
  %18 = tensor.empty() : tensor<12x64x12xf32>
  %19 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%15 : tensor<12x64x12xf32>) outs(%18 : tensor<12x64x12xf32>) {
  ^bb0(%in: f32, %out: f32):
    linalg.yield %in : f32
  } -> tensor<12x64x12xf32>
  %20 = tensor.empty() : tensor<12x12x12xf32>
  %21 = linalg.fill ins(%cst_0 : f32) outs(%20 : tensor<12x12x12xf32>) -> tensor<12x12x12xf32>
  %22 = linalg.batch_matmul ins(%17, %19 : tensor<12x12x64xf32>, tensor<12x64x12xf32>) outs(%21 : tensor<12x12x12xf32>) -> tensor<12x12x12xf32>
  %23 = tensor.empty() : tensor<12x12x12xf32>
  %24 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> ()>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%22, %cst_3 : tensor<12x12x12xf32>, tensor<f32>) outs(%23 : tensor<12x12x12xf32>) {
  ^bb0(%in: f32, %in_9: f32, %out: f32):
    %45 = arith.divf %in, %in_9 : f32
    linalg.yield %45 : f32
  } -> tensor<12x12x12xf32>
  %25 = tensor.empty() : tensor<12x12xf32>
  %26 = linalg.fill ins(%cst : f32) outs(%25 : tensor<12x12xf32>) -> tensor<12x12xf32>
  %27 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%24 : tensor<12x12x12xf32>) outs(%26 : tensor<12x12xf32>) {
  ^bb0(%in: f32, %out: f32):
    %45 = arith.maxf %in, %out : f32
    linalg.yield %45 : f32
  } -> tensor<12x12xf32>
  %28 = tensor.empty() : tensor<12x12x12xf32>
  %29 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%24, %27 : tensor<12x12x12xf32>, tensor<12x12xf32>) outs(%28 : tensor<12x12x12xf32>) {
  ^bb0(%in: f32, %in_9: f32, %out: f32):
    %45 = arith.subf %in, %in_9 : f32
    linalg.yield %45 : f32
  } -> tensor<12x12x12xf32>
  %30 = tensor.empty() : tensor<12x12x12xf32>
  %31 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%29 : tensor<12x12x12xf32>) outs(%30 : tensor<12x12x12xf32>) {
  ^bb0(%in: f32, %out: f32):
    %45 = math.exp %in : f32
    linalg.yield %45 : f32
  } -> tensor<12x12x12xf32>
  %32 = tensor.empty() : tensor<12x12xf32>
  %33 = linalg.fill ins(%cst_0 : f32) outs(%32 : tensor<12x12xf32>) -> tensor<12x12xf32>
  %34 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%31 : tensor<12x12x12xf32>) outs(%33 : tensor<12x12xf32>) {
  ^bb0(%in: f32, %out: f32):
    %45 = arith.addf %in, %out : f32
    linalg.yield %45 : f32
  } -> tensor<12x12xf32>
  %35 = tensor.empty() : tensor<12x12x12xf32>
  %36 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%31, %34 : tensor<12x12x12xf32>, tensor<12x12xf32>) outs(%35 : tensor<12x12x12xf32>) {
  ^bb0(%in: f32, %in_9: f32, %out: f32):
    %45 = arith.divf %in, %in_9 : f32
    linalg.yield %45 : f32
  } -> tensor<12x12x12xf32>
  %37 = tensor.empty() : tensor<12x12x12xf32>
  %38 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%36 : tensor<12x12x12xf32>) outs(%37 : tensor<12x12x12xf32>) {
  ^bb0(%in: f32, %out: f32):
    linalg.yield %in : f32
  } -> tensor<12x12x12xf32>
  %39 = tensor.empty() : tensor<12x12x64xf32>
  %40 = linalg.fill ins(%cst_0 : f32) outs(%39 : tensor<12x12x64xf32>) -> tensor<12x12x64xf32>
  %41 = linalg.batch_matmul ins(%38, %17 : tensor<12x12x12xf32>, tensor<12x12x64xf32>) outs(%40 : tensor<12x12x64xf32>) -> tensor<12x12x64xf32>
  %42 = tensor.empty() : tensor<12x12x64xf32>
  %43 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d1, d0, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%41 : tensor<12x12x64xf32>) outs(%42 : tensor<12x12x64xf32>) {
  ^bb0(%in: f32, %out: f32):
    linalg.yield %in : f32
  } -> tensor<12x12x64xf32>
  %expanded_7 = tensor.expand_shape %43 [[0, 1], [2], [3]] : tensor<12x12x64xf32> into tensor<1x12x12x64xf32>
  %collapsed_8 = tensor.collapse_shape %expanded_7 [[0], [1], [2, 3]] : tensor<1x12x12x64xf32> into tensor<1x12x768xf32>
  %44 = hal.tensor.export %collapsed_8 : tensor<1x12x768xf32> -> !hal.buffer_view
  return %44 : !hal.buffer_view
}

Then I end up with 7 dispatches, instead of 13 as was the case in the bug above… Here is the IR after some relevant passes https://gist.github.com/MaheshRavishankar/11cf9d4ff1beca0d004551f7f7d807b2

Looking at the remaining dispatches, there is a transpose at the very beginning. We can discuss whether we want to fuse it with the next dispatch, which is doing a matmul. That is a possibility, but as I said above, a better goal is to get the matmul into the right layout and then try to fuse the relayout operation with other operations, instead of indexing on the transpose and trying to fuse that with a matmul.

The other remaining inefficiencies are two dispatches with a single generic op each. These are doing two different transposes. It might be worth looking into where those transposes are coming from and then trying to fix them at the right place. My point is, when it comes to fusion it’s useful to look at the entire IR from input until the dispatch region formation and try to “clean up” the IR into something that will require far fewer actual rules within the fusion itself…