iree: stablehlo.sort pathologically slow
What happened?
Compile and run the IR below. It takes about 10 seconds to run on my A100, which is over 1000x off in performance.
$ iree-compile --iree-hal-target-backends=cuda --iree-input-type=xla --iree-hal-cuda-llvm-target-arch=sm_80 sort.mlir >sort.vmfb
$ iree-benchmark-module --device=cuda --module=sort.vmfb --function=main --input=4x8000xf32=0 --input=4x8000xi32=0 --benchmark_repetitions=2
module {
  func.func @main(%arg0: tensor<4x8000xf32>, %arg1: tensor<4x8000xi32>) -> (tensor<4x8000xf32>, tensor<4x8000xi32>) {
    %0:2 = "stablehlo.sort"(%arg0, %arg1) ({
    ^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>, %arg4: tensor<i32>, %arg5: tensor<i32>):
      %1 = stablehlo.compare GT, %arg2, %arg3, TOTALORDER : (tensor<f32>, tensor<f32>) -> tensor<i1>
      stablehlo.return %1 : tensor<i1>
    }) {dimension = 1 : i64, is_stable = true} : (tensor<4x8000xf32>, tensor<4x8000xi32>) -> (tensor<4x8000xf32>, tensor<4x8000xi32>)
    return %0#0, %0#1 : tensor<4x8000xf32>, tensor<4x8000xi32>
  }
}
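When checking a faster lowering against this repro, a reference for what the op computes can help. The plain-Python sketch below is not IREE code, and its function names are hypothetical. It models the sort above: each of the 4 rows is sorted along dimension 1 in descending order (comparator GT), ties keep their original order (is_stable = true), and the i32 operand is permuted alongside the f32 keys. TOTALORDER is modeled by mapping f32 bit patterns to integers, so -0.0 sorts below +0.0 and NaNs sort at the extremes according to their sign.

```python
import struct
from typing import List, Tuple


def total_order_key(x: float) -> int:
    """Map an f32 value to an integer that sorts in IEEE-754 totalOrder."""
    (bits,) = struct.unpack("<I", struct.pack("<f", x))
    # Negative values: flip every bit so larger magnitudes sort lower.
    # Non-negative values: set the sign bit so they sort above all negatives.
    return bits ^ 0xFFFFFFFF if bits & 0x80000000 else bits | 0x80000000


def sort_rows_desc_stable(
    keys: List[List[float]], vals: List[List[int]]
) -> Tuple[List[List[float]], List[List[int]]]:
    """Reference for the repro's sort: dimension = 1, GT + TOTALORDER, is_stable = true."""
    out_keys, out_vals = [], []
    for row_k, row_v in zip(keys, vals):
        # sorted() with reverse=True is still stable: equal keys keep their order.
        order = sorted(
            range(len(row_k)),
            key=lambda i: total_order_key(row_k[i]),
            reverse=True,
        )
        out_keys.append([row_k[i] for i in order])
        out_vals.append([row_v[i] for i in order])
    return out_keys, out_vals
```

With the all-zero inputs from the benchmark command, every comparison is a tie, so a stable sort returns its inputs unchanged.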
Steps to reproduce your issue
See above
What component(s) does this issue relate to?
Compiler
Version information
No response
Additional context
No response
About this issue
- State: open
- Created a year ago
- Comments: 26 (18 by maintainers)
@allieculp I posted an RFC for adding a topk op on the StableHLO GitHub and presented it during the last OpenXLA community meeting. I don’t have an ETA for the new op being available in IREE, but this is not on the critical path given the workarounds from Natasha and Rob.
Yes, I am not talking about an in-tree implementation. The custom PTX in the nvgpu plugin is what I was referring to above. As I mentioned internally, the topk implementation isn't built on these interfaces and was a one-off (I did push for using interfaces then, but got pushback). Having a one-off for sort as well would be sad… So if we want to solve this and be done with it, I'd happily help, but it will require some heavy lifting. If we aren't prepared to do that now, we might as well go with the easiest solution and avoid building another “heavy-weight one-off”.
Oh yeah, I was mostly saying that pattern matching this to the linalgext op may be the quick workaround, versus needing to change the original model to use chlo (if it doesn't already; someone needs to look). Matching iota + broadcast + sort seems reasonable.
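The iota + broadcast + sort pattern is how an argsort (and, followed by a slice, a top-k) often appears after lowering: an index tensor is built with iota along the sort dimension, broadcast to the full shape, and passed as an extra sort operand that gets permuted alongside the keys. The plain-Python sketch below shows the dataflow such a pattern match would need to recognize. The names iota_broadcast_sort and top_k are hypothetical, not IREE or StableHLO APIs, and plain float comparison is used, so NaN and signed-zero ordering are not modeled.

```python
from typing import List, Tuple


def iota_broadcast_sort(
    x: List[List[float]],
) -> Tuple[List[List[float]], List[List[int]]]:
    """Sort each row descending, carrying an index operand built by iota + broadcast."""
    n = len(x[0]) if x else 0
    # iota over the sort dimension, broadcast to every row: [[0, 1, ..., n-1], ...]
    indices = [list(range(n)) for _ in x]
    out_vals, out_idx = [], []
    for row, row_idx in zip(x, indices):
        # Stable descending sort, matching a GT comparator with is_stable = true.
        order = sorted(range(n), key=row.__getitem__, reverse=True)
        out_vals.append([row[i] for i in order])
        out_idx.append([row_idx[i] for i in order])
    return out_vals, out_idx


def top_k(
    x: List[List[float]], k: int
) -> Tuple[List[List[float]], List[List[int]]]:
    """Top-k per row: the first k entries of the sorted values and indices."""
    vals, idx = iota_broadcast_sort(x)
    return [r[:k] for r in vals], [r[:k] for r in idx]
```

Because the sort is stable, equal values come out with the lower index first, which is the tie-breaking a replacement top-k op would need to reproduce for the rewrite to be exact.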
The iree_linalg_ext.sort operation implements the TilingInterface. It's a small step from there to the PartialReductionOpInterface, which would allow sort to be handled similarly to split-k… Both the interface and its implementation/use in IREE need work, though…
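One way to picture handling sort "similar to split-k" is to split the sort dimension into chunks that are sorted independently (the parallel partial step, analogous to the partial reductions in split-k), then merge the sorted runs (the combine step). The plain-Python sketch below shows only that decomposition. chunked_sort_desc and num_chunks are hypothetical names, this is not how IREE or the PartialReductionOpInterface implement it, and plain float comparison is assumed (no NaN or signed-zero handling). Each entry is keyed by its original index as well, so the result is stable regardless of how the merge breaks ties.

```python
import heapq
from typing import List, Tuple


def chunked_sort_desc(
    keys: List[float], vals: List[int], num_chunks: int
) -> Tuple[List[float], List[int]]:
    """Stable descending sort built from independently sorted chunks plus a merge."""
    n = len(keys)
    chunk = max(1, -(-n // max(1, num_chunks)))  # ceil(n / num_chunks), at least 1

    def sort_key(i: int) -> Tuple[float, int]:
        # Descending by key; for equal keys, -i puts the lower original index first.
        return (keys[i], -i)

    # Partial step: each chunk is sorted on its own; these sorts are independent
    # and could run in parallel.
    runs = []
    for start in range(0, n, chunk):
        runs.append(sorted(range(start, min(start + chunk, n)), key=sort_key, reverse=True))

    # Combine step: k-way merge of the descending runs.
    merged = list(heapq.merge(*runs, key=sort_key, reverse=True))
    return [keys[i] for i in merged], [vals[i] for i in merged]
```

For the 4x8000 case, each row could be cut into a handful of chunks so more parallel work is available than one sort per row, at the cost of a merge pass afterwards.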