iree: Offset computation seems to introduce regressions on CUDA backends
PR #12874 introduces consistent handling of subspans with offsets on all backends. This avoids relying on GEPs to make the offsets work, since backends like SPIR-V and VMVX do not have GEP-like instructions.
Unfortunately, this seems to introduce regressions on the CUDA backend.
Here is the IR after certain passes on the Clips64 model before the change: https://gist.github.com/MaheshRavishankar/7996d674d878c2154d74d97e843018a1
Here is the IR after the change: https://gist.github.com/MaheshRavishankar/ee1d51447763d5fd8ee16f855b7bcc81
The issue here seems to be this kind of computation:
```mlir
^bb1: // pred: ^bb0
  %5 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c993664) flags(ReadOnly) : memref<12x77x64xf32, strided<[4928, 64, 1], offset: 248416>>
  memref.assume_alignment %5, 64 : memref<12x77x64xf32, strided<[4928, 64, 1], offset: 248416>>
  %6 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c1230208) flags(ReadOnly) : memref<12x64x77xf32, strided<[4928, 77, 1], offset: 307552>>
  memref.assume_alignment %6, 64 : memref<12x64x77xf32, strided<[4928, 77, 1], offset: 307552>>
  %7 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c946176) flags(ReadOnly) : memref<77x77xf32, strided<[77, 1], offset: 236544>>
  memref.assume_alignment %7, 64 : memref<77x77xf32, strided<[77, 1], offset: 236544>>
  %8 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c969920) flags(ReadOnly) : memref<77x77xf32, strided<[77, 1], offset: 242480>>
  memref.assume_alignment %8, 64 : memref<77x77xf32, strided<[77, 1], offset: 242480>>
  %9 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c1466752) : memref<12x77x77xf32, strided<[5929, 77, 1], offset: 366688>>
  memref.assume_alignment %9, 64 : memref<12x77x77xf32, strided<[5929, 77, 1], offset: 366688>>
  %workgroup_id_z = hal.interface.workgroup.id[2] : index
  %workgroup_id_x = hal.interface.workgroup.id[0] : index
  %10 = gpu.thread_id x
  %11 = gpu.thread_id y
  %12 = arith.muli %10, %c-3 : index
  %13 = arith.addi %12, %c77 : index
  %14 = arith.cmpi slt, %13, %c3 : index
  %15 = arith.select %14, %13, %c3 : index
  %16 = arith.cmpi slt, %15, %c0 : index
  %17 = arith.select %16, %c0, %15 : index
  %18 = arith.muli %11, %c4 : index
  %19 = arith.muli %workgroup_id_y, %c32 : index
  %20 = arith.addi %18, %19 : index
  %21 = arith.muli %10, %c3 : index
  %22 = arith.muli %workgroup_id_x, %c128 : index
  %23 = arith.addi %21, %22 : index
  %base_buffer, %offset, %sizes:3, %strides:3 = memref.extract_strided_metadata %9 : memref<12x77x77xf32, strided<[5929, 77, 1], offset: 366688>> -> memref<f32>, index, index, index, index, index, index, index
  %24 = arith.muli %workgroup_id_z, %c5929 : index
  %25 = arith.addi %24, %c366688 : index
  %26 = arith.muli %20, %c77 : index
  %27 = arith.addi %25, %26 : index
  %28 = arith.addi %27, %23 : index
  cf.br ^bb2(%c0 : index)
^bb2(%29: index): // 2 preds: ^bb1, ^bb6
  %30 = arith.cmpi slt, %29, %c4 : index
  cf.cond_br %30, ^bb3, ^bb7
^bb3: // pred: ^bb2
  %31 = arith.muli %29, %c77 : index
  %32 = arith.addi %28, %31 : index
  cf.br ^bb4(%c0 : index)
^bb4(%33: index): // 2 preds: ^bb3, ^bb5
  %34 = arith.cmpi slt, %33, %17 : index
  cf.cond_br %34, ^bb5, ^bb6
^bb5: // pred: ^bb4
  %35 = arith.addi %32, %33 : index
  %reinterpret_cast = memref.reinterpret_cast %base_buffer to offset: [%35], sizes: [1, 1, 1], strides: [5929, 77, 1] : memref<f32> to memref<1x1x1xf32, strided<[5929, 77, 1], offset: ?>>
  memref.store %cst, %reinterpret_cast[%c0, %c0, %c0] : memref<1x1x1xf32, strided<[5929, 77, 1], offset: ?>>
  %36 = arith.addi %33, %c1 : index
  cf.br ^bb4(%36 : index)
^bb6: // pred: ^bb4
  %37 = arith.addi %29, %c1 : index
  cf.br ^bb2(%37 : index)
^bb7: // pred: ^bb2
  %38 = gpu.thread_id z
  %39 = arith.muli %10, %c4 : index
  %base_buffer_1, %offset_2, %sizes_3:3, %strides_4:3 = memref.extract_strided_metadata %5 : memref<12x77x64xf32, strided<[4928, 64, 1], offset: 248416>> -> memref<f32>, index, index, index, index, index, index, index
  %base_buffer_5, %offset_6, %sizes_7:3, %strides_8:3 = memref.extract_strided_metadata %6 : memref<12x64x77xf32, strided<[4928, 77, 1], offset: 307552>> -> memref<f32>, index, index, index, index, index, index, index
  %40 = arith.muli %18, %c36 : index
  cf.br ^bb8(%c0 : index)
^bb8(%41: index): // 2 preds: ^bb7, ^bb41
  %42 = arith.cmpi slt, %41, %c64 : index
  cf.cond_br %42, ^bb9, ^bb42
^bb41: // pred: ^bb33
  %123 = arith.addi %41, %c32 : index
  cf.br ^bb8(%123 : index)
^bb42: // pred: ^bb8
  %base_buffer_16, %offset_17, %sizes_18:2, %strides_19:2 = memref.extract_strided_metadata %7 : memref<77x77xf32, strided<[77, 1], offset: 236544>> -> memref<f32>, index, index, index, index, index
  %base_buffer_20, %offset_21, %sizes_22:2, %strides_23:2 = memref.extract_strided_metadata %8 : memref<77x77xf32, strided<[77, 1], offset: 242480>> -> memref<f32>, index, index, index, index, index
  %124 = arith.addi %26, %c236544 : index
  %125 = arith.addi %124, %23 : index
  %126 = arith.addi %26, %c242480 : index
  %127 = arith.addi %126, %23 : index
```
The last four instructions are mostly hoistable (if not all of them), because their operands (`%26` and `%23`) are defined way up in `^bb1`.
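To make the hoisting observation concrete, here is a small Python sketch of the index arithmetic in the IR above. It only models the arithmetic and is not IREE code; the function names and the sample thread/workgroup ids are made up for illustration.

```python
# Toy model of the index arithmetic in the IR above; not IREE code.
# Function names and the sample ids below are illustrative assumptions.

def entry_block_values(tid_x, tid_y, wg_x, wg_y):
    """Values computed once in ^bb1: returns (%26, %23)."""
    v20 = tid_y * 4 + wg_y * 32    # %20 = %18 + %19
    v23 = tid_x * 3 + wg_x * 128   # %23 = %21 + %22
    v26 = v20 * 77                 # %26 = %20 * 77
    return v26, v23


def late_offsets(v26, v23):
    """%125 and %127 from ^bb42: only ^bb1 values and constants are used."""
    v125 = (v26 + 236544) + v23    # %124, then %125
    v127 = (v26 + 242480) + v23    # %126, then %127
    return v125, v127


def store_index(wg_z, v26, v23, i, j):
    """%35: linear index of the store in ^bb5 (i models %29, j models %33)."""
    v28 = wg_z * 5929 + 366688 + v26 + v23   # %24 through %28
    return v28 + i * 77 + j                  # %32, then %35


if __name__ == "__main__":
    v26, v23 = entry_block_values(tid_x=1, tid_y=2, wg_x=0, wg_y=1)
    print(late_offsets(v26, v23))
    print(store_index(0, v26, v23, i=0, j=0))
```

Since `late_offsets` takes nothing but the two entry-block values and constants, the four adds could in principle be computed once in `^bb1`.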
The missed hoisting might not be the cause of the regression, though. From @ThomasRaoux on Discord:
> it seems like the additional address computation is breaking the alignment propagation.
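One way to picture what alignment propagation has to track: in the IR above, every subspan byte offset is a multiple of 64 and equals four times the element offset in the memref type (f32 is 4 bytes). Once a per-thread element index is folded into the same address, the arithmetic alone only guarantees element-size alignment. The Python sketch below models this with hypothetical helpers (`known_alignment`, `worst_case_alignment`); it is a toy illustration, not how LLVM or MLIR actually track alignment.

```python
# Toy model of alignment reasoning; not LLVM's or MLIR's analysis.
# `known_alignment` and `worst_case_alignment` are hypothetical helpers.

F32_BYTES = 4

# (byte offset operand, element offset in the memref type) from the IR above.
SUBSPANS = {
    "%5": (993664, 248416),
    "%6": (1230208, 307552),
    "%7": (946176, 236544),
    "%8": (969920, 242480),
    "%9": (1466752, 366688),
}


def known_alignment(base_align, byte_offset):
    """Largest power of two <= base_align that divides byte_offset."""
    align = base_align
    while align > 1 and byte_offset % align != 0:
        align //= 2
    return align


def worst_case_alignment(base_align, byte_offset, stride_bytes):
    """Alignment guaranteed for byte_offset + k * stride_bytes, any integer k."""
    return min(known_alignment(base_align, byte_offset),
               known_alignment(base_align, stride_bytes))


if __name__ == "__main__":
    for name, (byte_off, elem_off) in SUBSPANS.items():
        print(name, byte_off == elem_off * F32_BYTES,
              known_alignment(64, byte_off))
    print(worst_case_alignment(64, 1466752, F32_BYTES))
```

In this model, the constant offsets alone keep the 64-byte guarantee, while any address that also includes an unknown multiple of the element size is only known to be 4-byte aligned unless the analysis can prove more about the index.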
About this issue
- State: closed
- Created a year ago
- Comments: 33 (29 by maintainers)
Commits related to this issue
- Make subspan handling consistent across backends. (#12874) The VMVX and SPIR-V sides have been changed to treat memref result type of hal.interface.binding.subspan with offsets from ``` hal.inter... — committed to iree-org/iree by MaheshRavishankar a year ago
- Make handling subspan with offsets and memref return types for CUDA consistent with other backends. (#13378) With https://github.com/iree-org/iree-llvm-fork/commit/e02d4142dca48d9359ad304fec629ea3d7e... — committed to iree-org/iree by MaheshRavishankar a year ago
- Make subspan handling consistent across backends. (#12874) The VMVX and SPIR-V sides have been changed to treat memref result type of hal.interface.binding.subspan with offsets from ``` hal.inter... — committed to NatashaKnk/iree by MaheshRavishankar a year ago
- Make handling subspan with offsets and memref return types for CUDA consistent with other backends. (#13378) With https://github.com/iree-org/iree-llvm-fork/commit/e02d4142dca48d9359ad304fec629ea3d7e... — committed to NatashaKnk/iree by MaheshRavishankar a year ago
Not yet. The LLVM changes haven’t come through. I am tracking it.
The fixes are merged in LLVM upstream at respectively:
Like Mahesh said, we need to bump LLVM in IREE to get them, then disable the WAR.
I think this is waiting on an integrate for the fixes to land in IREE (not sure if it is submitted upstream yet). Once that comes into IREE, I will flip the WAR and check for regressions.
I’d rather keep this as the P0 here. (Also, it’s P0, but I hope it doesn’t mean you have to look at this during vacation 😛.) The issue to summarize is
`memref`s (which is consistent across all backends).
My recommendation is to try to address the performance issue with CUDA. (I think the real issue is alignment propagation in LLVM. We should really have alignment propagation in MLIR, which we’ve discussed many times before; if we controlled the alignment propagation in MLIR, this regression wouldn’t have happened.) I am not sure what the timelines for P0 are… but it would be a bad result if we find a short-term solution for the problem here.
This is the input to the Clips model that is in Linalg dialect.
The command line I used is
This won’t work on ToT with IREE since I added a WAR for CUDA. So you probably need to set
to `true` to repro the issue.
You could also start with
which is basically all these passes, on this function
(which is the IR for `dispatch_11` from the gist above, before the `expand-address-computation-gpu` pass).