iree: Offset computation seems to introduce regressions on CUDA backends

PR #12874 introduces a consistent handling of subspan with offsets on all backends. This avoids relying on using GEPs to make the offsets work, since backends like SPIR-V and VMVX do not have GEP-like instructions.

Unfortunately this seems to introduce regressions on the CUDA backend. Here is the IR after certain passes on the Clips64 model before the change: https://gist.github.com/MaheshRavishankar/7996d674d878c2154d74d97e843018a1 — and here is the IR after the change: https://gist.github.com/MaheshRavishankar/ee1d51447763d5fd8ee16f855b7bcc81

The issue here seems to be this kind of computation

 ^bb1:  // pred: ^bb0
    %5 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c993664) flags(ReadOnly) : memref<12x77x64xf32, strided<[4928, 64, 1], offset: 248416>>
    memref.assume_alignment %5, 64 : memref<12x77x64xf32, strided<[4928, 64, 1], offset: 248416>>
    %6 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c1230208) flags(ReadOnly) : memref<12x64x77xf32, strided<[4928, 77, 1], offset: 307552>>
    memref.assume_alignment %6, 64 : memref<12x64x77xf32, strided<[4928, 77, 1], offset: 307552>>
    %7 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c946176) flags(ReadOnly) : memref<77x77xf32, strided<[77, 1], offset: 236544>>
    memref.assume_alignment %7, 64 : memref<77x77xf32, strided<[77, 1], offset: 236544>>
    %8 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c969920) flags(ReadOnly) : memref<77x77xf32, strided<[77, 1], offset: 242480>>
    memref.assume_alignment %8, 64 : memref<77x77xf32, strided<[77, 1], offset: 242480>>
    %9 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c1466752) : memref<12x77x77xf32, strided<[5929, 77, 1], offset: 366688>>
    memref.assume_alignment %9, 64 : memref<12x77x77xf32, strided<[5929, 77, 1], offset: 366688>>
    %workgroup_id_z = hal.interface.workgroup.id[2] : index
    %workgroup_id_x = hal.interface.workgroup.id[0] : index
    %10 = gpu.thread_id  x
    %11 = gpu.thread_id  y
    %12 = arith.muli %10, %c-3 : index
    %13 = arith.addi %12, %c77 : index
    %14 = arith.cmpi slt, %13, %c3 : index
    %15 = arith.select %14, %13, %c3 : index
    %16 = arith.cmpi slt, %15, %c0 : index
    %17 = arith.select %16, %c0, %15 : index
    %18 = arith.muli %11, %c4 : index
    %19 = arith.muli %workgroup_id_y, %c32 : index
    %20 = arith.addi %18, %19 : index
    %21 = arith.muli %10, %c3 : index
    %22 = arith.muli %workgroup_id_x, %c128 : index
    %23 = arith.addi %21, %22 : index
    %base_buffer, %offset, %sizes:3, %strides:3 = memref.extract_strided_metadata %9 : memref<12x77x77xf32, strided<[5929, 77, 1], offset: 366688>> -> memref<f32>, index, index, index, index, index, index, index
    %24 = arith.muli %workgroup_id_z, %c5929 : index
    %25 = arith.addi %24, %c366688 : index
    %26 = arith.muli %20, %c77 : index
    %27 = arith.addi %25, %26 : index
    %28 = arith.addi %27, %23 : index
   cf.br ^bb2(%c0 : index)
  ^bb2(%29: index):  // 2 preds: ^bb1, ^bb6
    %30 = arith.cmpi slt, %29, %c4 : index
    cf.cond_br %30, ^bb3, ^bb7
  ^bb3:  // pred: ^bb2
    %31 = arith.muli %29, %c77 : index
    %32 = arith.addi %28, %31 : index
    cf.br ^bb4(%c0 : index)
  ^bb4(%33: index):  // 2 preds: ^bb3, ^bb5
    %34 = arith.cmpi slt, %33, %17 : index
    cf.cond_br %34, ^bb5, ^bb6
  ^bb5:  // pred: ^bb4
    %35 = arith.addi %32, %33 : index
    %reinterpret_cast = memref.reinterpret_cast %base_buffer to offset: [%35], sizes: [1, 1, 1], strides: [5929, 77, 1] : memref<f32> to memref<1x1x1xf32, strided<[5929, 77, 1], offset: ?>>
    memref.store %cst, %reinterpret_cast[%c0, %c0, %c0] : memref<1x1x1xf32, strided<[5929, 77, 1], offset: ?>>
    %36 = arith.addi %33, %c1 : index
    cf.br ^bb4(%36 : index)
  ^bb6:  // pred: ^bb4
    %37 = arith.addi %29, %c1 : index
    cf.br ^bb2(%37 : index)
  ^bb7:  // pred: ^bb2
    %38 = gpu.thread_id  z
    %39 = arith.muli %10, %c4 : index
    %base_buffer_1, %offset_2, %sizes_3:3, %strides_4:3 = memref.extract_strided_metadata %5 : memref<12x77x64xf32, strided<[4928, 64, 1], offset: 248416>> -> memref<f32>, index, index, index, index, index, index, index
    %base_buffer_5, %offset_6, %sizes_7:3, %strides_8:3 = memref.extract_strided_metadata %6 : memref<12x64x77xf32, strided<[4928, 77, 1], offset: 307552>> -> memref<f32>, index, index, index, index, index, index, index
    %40 = arith.muli %18, %c36 : index
    cf.br ^bb8(%c0 : index)
  ^bb8(%41: index):  // 2 preds: ^bb7, ^bb41
    %42 = arith.cmpi slt, %41, %c64 : index
    cf.cond_br %42, ^bb9, ^bb42
  ^bb41:  // pred: ^bb33
    %123 = arith.addi %41, %c32 : index
    cf.br ^bb8(%123 : index)
 ^bb42:  // pred: ^bb8
    %base_buffer_16, %offset_17, %sizes_18:2, %strides_19:2 = memref.extract_strided_metadata %7 : memref<77x77xf32, strided<[77, 1], offset: 236544>> -> memref<f32>, index, index, index, index, index
    %base_buffer_20, %offset_21, %sizes_22:2, %strides_23:2 = memref.extract_strided_metadata %8 : memref<77x77xf32, strided<[77, 1], offset: 242480>> -> memref<f32>, index, index, index, index, index
    %124 = arith.addi %26, %c236544 : index
    %125 = arith.addi %124, %23 : index
    %126 = arith.addi %26, %c242480 : index
    %127 = arith.addi %126, %23 : index

The last 4 instructions are mostly (if not all) hoistable, because they come from values defined much earlier in the block.

This might not be the cause of the regression though. From @ThomasRaoux on Discord (see the attached image):

it seems like the additional address computation is breaking the alignment propagation.

About this issue

  • Original URL
  • State: closed
  • Created a year ago
  • Comments: 33 (29 by maintainers)

Commits related to this issue

Most upvoted comments

Not yet. The LLVM changes haven’t come through. I am tracking it

The fixes are merged in LLVM upstream at respectively:

  • 68e1aef68e40452b6c176b25e67c13a0359c96ca and
  • e02d4142dca48d9359ad304fec629ea3d7e6924c

Like Mahesh said, we need to bump LLVM in IREE to get them, then disable the WAR.

I think this is waiting on an integrate for the fixes to land in IREE (not sure if it has been submitted upstream yet). Once that comes into IREE, I will flip the WAR and check for regressions.

Agree. @allieculp could you duplicate that issue specifically for the spir-v issue and make that one a P0.

I’d rather keep this as the P0 here. (Also, it’s P0, but I hope it doesn’t mean you have to look at this during vacation, 😛.) The issue, to summarize, is

  1. CUDA needs a workaround due to regressions with the new way of handling memrefs (which is consistent across all backends).
  2. SPIR-V does not need the workaround. Today the transform dialect path just says GPU for SPIR-V and CUDA. It’d be really good to just drop the workaround for the CUDA side, since right now it is on its own island w.r.t memref handling. Don’t know what effects/divergence it will have over time. So that is the reason for me to keep the priority high (with the understanding that it might take time to fix). It’d be really undesirable to have a workaround for a workaround to carve out a special path for SPIR-V. This can spiral very fast.

My recommendation is to try to address the performance issue with CUDA (I think the real issue is alignment propagation in LLVM. We should really have alignment propagation in MLIR, which we’ve discussed many times before; if we controlled the alignment propagation in MLIR, this regression wouldn’t have happened). I am not sure what the timelines for P0 are… but it would be a bad result if we find only a short-term solution for the problem here.

That should be kind of a small enough repro?

Unfortunately no, unless I miss some flags with iree-compile/iree-opt.

I can reproduce what the address extraction does just fine, but I don’t see where in the pipeline we go south. I.e., what this pass does that prevents the other passes to do their job.

Is there a sort of iree-compile -start-before=<pass-name> that would start the compiler pipeline at <pass-name> and continue from there? (We have that at the llc level and that’s super convenient to reproduce this kind of issues.)

If you have the iree-compile command handy, that should be enough otherwise I’ll have to dig them up from the benchmark suite.

Alternatively I could build manually the pipeline with iree-opt, but since you don’t have all the passes in the dumps, I’m missing some passes to run.

This is the input to the Clips model that is in Linalg dialect.

The command line I used is

iree-compiler --iree-hal-target-backends=cuda --iree-input-type=mhlo --iree-hal-cuda-llvm-target-arch=sm_80 --output-format=vm-asm <input>

This won’t work on ToT IREE since I added a WAR for CUDA, so you probably need to set the corresponding WAR flag (the flag name was elided in the original post — TODO: confirm which one) to true to repro the issue.

You could also start with

iree-opt --pass-pipeline="builtin.module(func.func(extract-address-computation-gpu, memref-expand, expand-strided-metadata, loop-invariant-code-motion, decompose-affine-ops, cse, loop-invariant-code-motion, lower-affine))"

which is basically all these passes, applied to this function:

module {
  func.func @forward_dispatch_11_batch_matmul_12x77x77x64() {
    %c3 = arith.constant 3 : index
    %c-3 = arith.constant -3 : index
    %c-32 = arith.constant -32 : index
    %c4 = arith.constant 4 : index
    %c8 = arith.constant 8 : index
    %c1 = arith.constant 1 : index
    %c77 = arith.constant 77 : index
    %c128 = arith.constant 128 : index
    %c32 = arith.constant 32 : index
    %c993664 = arith.constant 993664 : index
    %c1230208 = arith.constant 1230208 : index
    %c946176 = arith.constant 946176 : index
    %c969920 = arith.constant 969920 : index
    %c1466752 = arith.constant 1466752 : index
    %cst = arith.constant 0.000000e+00 : f32
    %c64 = arith.constant 64 : index
    %c0 = arith.constant 0 : index
    %alloc = memref.alloc() : memref<1x32x81xf32, #gpu.address_space<workgroup>>
    %alloc_0 = memref.alloc() : memref<1x32x36xf32, #gpu.address_space<workgroup>>
    %workgroup_id_y = hal.interface.workgroup.id[1] : index
    %0 = arith.muli %workgroup_id_y, %c-32 : index
    %1 = arith.addi %0, %c77 : index
    %2 = arith.cmpi slt, %1, %c32 : index
    %3 = arith.select %2, %1, %c32 : index
    %4 = arith.cmpi eq, %3, %c32 : index
    scf.if %4 {
      %5 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c993664) flags(ReadOnly) : memref<12x77x64xf32, strided<[4928, 64, 1], offset: 248416>>
      memref.assume_alignment %5, 64 : memref<12x77x64xf32, strided<[4928, 64, 1], offset: 248416>>
      %6 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c1230208) flags(ReadOnly) : memref<12x64x77xf32, strided<[4928, 77, 1], offset: 307552>>
      memref.assume_alignment %6, 64 : memref<12x64x77xf32, strided<[4928, 77, 1], offset: 307552>>
      %7 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c946176) flags(ReadOnly) : memref<77x77xf32, strided<[77, 1], offset: 236544>>
      memref.assume_alignment %7, 64 : memref<77x77xf32, strided<[77, 1], offset: 236544>>
      %8 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c969920) flags(ReadOnly) : memref<77x77xf32, strided<[77, 1], offset: 242480>>
      memref.assume_alignment %8, 64 : memref<77x77xf32, strided<[77, 1], offset: 242480>>
      %9 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c1466752) : memref<12x77x77xf32, strided<[5929, 77, 1], offset: 366688>>
      memref.assume_alignment %9, 64 : memref<12x77x77xf32, strided<[5929, 77, 1], offset: 366688>>
      %workgroup_id_z = hal.interface.workgroup.id[2] : index
      %workgroup_id_x = hal.interface.workgroup.id[0] : index
      %10 = gpu.thread_id  x
      %11 = gpu.thread_id  y
      %12 = arith.muli %10, %c-3 : index
      %13 = arith.addi %12, %c77 : index
      %14 = arith.cmpi slt, %13, %c3 : index
      %15 = arith.select %14, %13, %c3 : index
      %16 = arith.cmpi slt, %15, %c0 : index
      %17 = arith.select %16, %c0, %15 : index
      %18 = arith.muli %11, %c4 : index
      %19 = arith.muli %workgroup_id_y, %c32 : index
      %20 = arith.addi %18, %19 : index
      %21 = arith.muli %10, %c3 : index
      %22 = arith.muli %workgroup_id_x, %c128 : index
      %23 = arith.addi %21, %22 : index
      scf.for %arg0 = %c0 to %c4 step %c1 {
        scf.for %arg1 = %c0 to %17 step %c1 {
          %26 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%20, %arg0]
          %27 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%23, %arg1]
          memref.store %cst, %9[%workgroup_id_z, %26, %27] : memref<12x77x77xf32, strided<[5929, 77, 1], offset: 366688>>
        }
      }
      %24 = gpu.thread_id  z
      %25 = arith.muli %10, %c4 : index
      scf.for %arg0 = %c0 to %c64 step %c32 {
        gpu.barrier
        scf.for %arg1 = %24 to %c1 step %c1 {
          %26 = arith.addi %workgroup_id_z, %arg1 : index
          scf.for %arg2 = %11 to %c32 step %c8 {
            %27 = arith.addi %arg2, %19 : index
            scf.for %arg3 = %25 to %c32 step %c128 {
              %28 = arith.addi %arg0, %arg3 : index
              scf.for %arg4 = %c0 to %c4 step %c1 {
                %29 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%28, %arg4]
                %30 = memref.load %5[%26, %27, %29] : memref<12x77x64xf32, strided<[4928, 64, 1], offset: 248416>>
                %31 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%arg3, %arg4]
                memref.store %30, %alloc_0[%arg1, %arg2, %31] : memref<1x32x36xf32, #gpu.address_space<workgroup>>
              }
            }
          }
        }
        scf.for %arg1 = %24 to %c1 step %c1 {
          %26 = arith.addi %workgroup_id_z, %arg1 : index
          scf.for %arg2 = %11 to %c32 step %c8 {
            %27 = arith.addi %arg0, %arg2 : index
            scf.for %arg3 = %25 to %c77 step %c128 {
              %28 = arith.subi %c77, %arg3 : index
              %29 = arith.cmpi slt, %28, %c4 : index
              %30 = arith.select %29, %28, %c4 : index
              %31 = arith.addi %arg3, %22 : index
              scf.for %arg4 = %c0 to %30 step %c1 {
                %32 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%31, %arg4]
                %33 = memref.load %6[%26, %27, %32] : memref<12x64x77xf32, strided<[4928, 77, 1], offset: 307552>>
                %34 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%arg3, %arg4]
                memref.store %33, %alloc[%arg1, %arg2, %34] : memref<1x32x81xf32, #gpu.address_space<workgroup>>
              }
            }
          }
        }
        gpu.barrier
        scf.for %arg1 = %c0 to %c4 step %c1 {
          scf.for %arg2 = %c0 to %17 step %c1 {
            scf.for %arg3 = %c0 to %c32 step %c1 {
              %26 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%18, %arg1]
              %27 = memref.load %alloc_0[%c0, %26, %arg3] : memref<1x32x36xf32, #gpu.address_space<workgroup>>
              %28 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%21, %arg2]
              %29 = memref.load %alloc[%c0, %arg3, %28] : memref<1x32x81xf32, #gpu.address_space<workgroup>>
              %30 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%20, %arg1]
              %31 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%23, %arg2]
              %32 = memref.load %9[%workgroup_id_z, %30, %31] : memref<12x77x77xf32, strided<[5929, 77, 1], offset: 366688>>
              %33 = arith.mulf %27, %29 : f32
              %34 = arith.addf %32, %33 : f32
              %35 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%20, %arg1]
              %36 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%23, %arg2]
              memref.store %34, %9[%workgroup_id_z, %35, %36] : memref<12x77x77xf32, strided<[5929, 77, 1], offset: 366688>>
            }
          }
        }
      }
      scf.for %arg0 = %c0 to %c4 step %c1 {
        scf.for %arg1 = %c0 to %17 step %c1 {
          %26 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%20, %arg0]
          %27 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%23, %arg1]
          %28 = memref.load %7[%26, %27] : memref<77x77xf32, strided<[77, 1], offset: 236544>>
          %29 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%20, %arg0]
          %30 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%23, %arg1]
          %31 = memref.load %8[%29, %30] : memref<77x77xf32, strided<[77, 1], offset: 242480>>
          %32 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%20, %arg0]
          %33 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%23, %arg1]
          %34 = memref.load %9[%workgroup_id_z, %32, %33] : memref<12x77x77xf32, strided<[5929, 77, 1], offset: 366688>>
          %35 = arith.addf %34, %28 : f32
          %36 = arith.addf %35, %31 : f32
          %37 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%20, %arg0]
          %38 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%23, %arg1]
          memref.store %36, %9[%workgroup_id_z, %37, %38] : memref<12x77x77xf32, strided<[5929, 77, 1], offset: 366688>>
        }
      }
    } else {
      %5 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c993664) flags(ReadOnly) : memref<12x77x64xf32, strided<[4928, 64, 1], offset: 248416>>
      memref.assume_alignment %5, 64 : memref<12x77x64xf32, strided<[4928, 64, 1], offset: 248416>>
      %6 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c1230208) flags(ReadOnly) : memref<12x64x77xf32, strided<[4928, 77, 1], offset: 307552>>
      memref.assume_alignment %6, 64 : memref<12x64x77xf32, strided<[4928, 77, 1], offset: 307552>>
      %7 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c946176) flags(ReadOnly) : memref<77x77xf32, strided<[77, 1], offset: 236544>>
      memref.assume_alignment %7, 64 : memref<77x77xf32, strided<[77, 1], offset: 236544>>
      %8 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c969920) flags(ReadOnly) : memref<77x77xf32, strided<[77, 1], offset: 242480>>
      memref.assume_alignment %8, 64 : memref<77x77xf32, strided<[77, 1], offset: 242480>>
      %9 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c1466752) : memref<12x77x77xf32, strided<[5929, 77, 1], offset: 366688>>
      memref.assume_alignment %9, 64 : memref<12x77x77xf32, strided<[5929, 77, 1], offset: 366688>>
      %workgroup_id_z = hal.interface.workgroup.id[2] : index
      %workgroup_id_x = hal.interface.workgroup.id[0] : index
      %10 = gpu.thread_id  x
      %11 = gpu.thread_id  y
      %12 = arith.cmpi sle, %3, %c0 : index
      %13 = arith.subi %c0, %3 : index
      %14 = arith.subi %3, %c1 : index
      %15 = arith.select %12, %13, %14 : index
      %16 = arith.divsi %15, %c8 : index
      %17 = arith.subi %c0, %16 : index
      %18 = arith.addi %16, %c1 : index
      %19 = arith.select %12, %17, %18 : index
      %20 = arith.muli %11, %19 : index
      %21 = arith.subi %3, %20 : index
      %22 = arith.cmpi slt, %21, %19 : index
      %23 = arith.select %22, %21, %19 : index
      %24 = arith.cmpi slt, %23, %c0 : index
      %25 = arith.select %24, %c0, %23 : index
      %26 = arith.muli %10, %c-3 : index
      %27 = arith.addi %26, %c77 : index
      %28 = arith.cmpi slt, %27, %c3 : index
      %29 = arith.select %28, %27, %c3 : index
      %30 = arith.cmpi slt, %29, %c0 : index
      %31 = arith.select %30, %c0, %29 : index
      %32 = arith.muli %workgroup_id_y, %c32 : index
      %33 = arith.addi %20, %32 : index
      %34 = arith.muli %10, %c3 : index
      %35 = arith.muli %workgroup_id_x, %c128 : index
      %36 = arith.addi %34, %35 : index
      scf.for %arg0 = %c0 to %25 step %c1 {
        scf.for %arg1 = %c0 to %31 step %c1 {
          %37 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%33, %arg0]
          %38 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%36, %arg1]
          memref.store %cst, %9[%workgroup_id_z, %37, %38] : memref<12x77x77xf32, strided<[5929, 77, 1], offset: 366688>>
        }
      }
      scf.for %arg0 = %c0 to %c64 step %c32 {
        scf.for %arg1 = %c0 to %25 step %c1 {
          scf.for %arg2 = %c0 to %31 step %c1 {
            scf.for %arg3 = %c0 to %c32 step %c1 {
              %37 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%33, %arg1]
              %38 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%arg0, %arg3]
              %39 = memref.load %5[%workgroup_id_z, %37, %38] : memref<12x77x64xf32, strided<[4928, 64, 1], offset: 248416>>
              %40 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%arg0, %arg3]
              %41 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%36, %arg2]
              %42 = memref.load %6[%workgroup_id_z, %40, %41] : memref<12x64x77xf32, strided<[4928, 77, 1], offset: 307552>>
              %43 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%33, %arg1]
              %44 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%36, %arg2]
              %45 = memref.load %9[%workgroup_id_z, %43, %44] : memref<12x77x77xf32, strided<[5929, 77, 1], offset: 366688>>
              %46 = arith.mulf %39, %42 : f32
              %47 = arith.addf %45, %46 : f32
              %48 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%33, %arg1]
              %49 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%36, %arg2]
              memref.store %47, %9[%workgroup_id_z, %48, %49] : memref<12x77x77xf32, strided<[5929, 77, 1], offset: 366688>>
            }
          }
        }
      }
      scf.for %arg0 = %c0 to %25 step %c1 {
        scf.for %arg1 = %c0 to %31 step %c1 {
          %37 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%33, %arg0]
          %38 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%36, %arg1]
          %39 = memref.load %7[%37, %38] : memref<77x77xf32, strided<[77, 1], offset: 236544>>
          %40 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%33, %arg0]
          %41 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%36, %arg1]
          %42 = memref.load %8[%40, %41] : memref<77x77xf32, strided<[77, 1], offset: 242480>>
          %43 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%33, %arg0]
          %44 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%36, %arg1]
          %45 = memref.load %9[%workgroup_id_z, %43, %44] : memref<12x77x77xf32, strided<[5929, 77, 1], offset: 366688>>
          %46 = arith.addf %45, %39 : f32
          %47 = arith.addf %46, %42 : f32
          %48 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%33, %arg0]
          %49 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%36, %arg1]
          memref.store %47, %9[%workgroup_id_z, %48, %49] : memref<12x77x77xf32, strided<[5929, 77, 1], offset: 366688>>
        }
      }
    }
    return
  }
}

(which is the IR for dispatch_11 from the gist above, before the extract-address-computation-gpu pass.)