iree: stablehlo.sort pathologically slow

What happened?

Compile and run the IR below. The compiled module takes about 10 seconds to run on my A100, which is over 1000x slower than it should be.

$ iree-compile --iree-hal-target-backends=cuda --iree-input-type=xla --iree-hal-cuda-llvm-target-arch=sm_80 sort.mlir >sort.vmfb
$ iree-benchmark-module --device=cuda --module=sort.vmfb --function=main --input=4x8000xf32=0 --input=4x8000xi32=0 --benchmark_repetitions=2
module {
  func.func @main(%arg0: tensor<4x8000xf32>, %arg1: tensor<4x8000xi32>) -> (tensor<4x8000xf32>, tensor<4x8000xi32>) {
    %0:2 = "stablehlo.sort"(%arg0, %arg1) ({
    ^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>, %arg4: tensor<i32>, %arg5: tensor<i32>):
      %1 = stablehlo.compare  GT, %arg2, %arg3,  TOTALORDER : (tensor<f32>, tensor<f32>) -> tensor<i1>
      stablehlo.return %1 : tensor<i1>
    }) {dimension = 1 : i64, is_stable = true} : (tensor<4x8000xf32>, tensor<4x8000xi32>) -> (tensor<4x8000xf32>, tensor<4x8000xi32>)
    return %0#0, %0#1 : tensor<4x8000xf32>, tensor<4x8000xi32>
  }
}

Steps to reproduce your issue

See above

What component(s) does this issue relate to?

Compiler

Version information

No response

Additional context

No response

About this issue

  • State: open
  • Created a year ago
  • Comments: 26 (18 by maintainers)

Most upvoted comments

@kuhar Do you have any update on this issue?

@allieculp I posted an RFC for adding a topk op on the StableHLO GitHub and presented it at the last OpenXLA community meeting. I don’t have an ETA for the new op being available in IREE, but this is not on the critical path given the workarounds from Natasha and Rob.

Hrrmmm, we can’t close this issue if the solution is “write some PTX”. That may unblock the user here, but it is definitely still a major issue for the platform, as Vulkan/Metal/CPU will have extremely slow sorts, and we don’t want to be making one-off workarounds in the core platform in lieu of actually solving the problem better. If the heuristic makes things better we should do that, and then, if the user needs even better performance, they can add custom PTX in the nvgpu plugin.

Yes, I am not talking about doing this in tree. The custom PTX in the nvgpu plugin is what I am describing above. As I mentioned internally, the topk implementation isn’t done using these interfaces and was a one-off (I did push for using interfaces then, but got pushback). Having a sort one-off would be sad… So if we want to solve it and be done with it, I’d happily help, but this will require some heavy lifting. If we aren’t prepared to do that now, then we might as well go with the easiest solution to avoid building another “heavy-weight one-off”.

Oh yeah, I was more saying that pattern matching this to the linalgext op may be the quick workaround, versus needing to change the original model to use chlo (if it’s not already using it; someone needs to look). iota + broadcast + sort seems reasonable.
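For readers unfamiliar with the pattern, here is a rough Python sketch of what "iota + broadcast + sort" usually looks like when a model expresses top-k through a full sort. It assumes the second operand of the sort in the IR above is such an index tensor, which the issue does not confirm. The function names (`iota_broadcast`, `sort_with_indices`, `top_k`) are made up for illustration and are not IREE or StableHLO APIs.

```python
def iota_broadcast(rows, cols):
    # Index tensor: each row is [0, 1, ..., cols - 1]
    # (an iota along dimension 1, broadcast across the rows).
    return [list(range(cols)) for _ in range(rows)]


def sort_with_indices(values):
    # Stable descending sort of each row, carrying the indices along.
    # Mirrors the GT comparator and is_stable = true in the IR above,
    # ignoring NaN and signed-zero handling (an assumption of this sketch).
    rows, cols = len(values), len(values[0])
    indices = iota_broadcast(rows, cols)
    sorted_vals, sorted_idx = [], []
    for row_vals, row_idx in zip(values, indices):
        # Python's sort is stable, so equal keys keep their original order.
        pairs = sorted(zip(row_vals, row_idx), key=lambda p: -p[0])
        sorted_vals.append([v for v, _ in pairs])
        sorted_idx.append([i for _, i in pairs])
    return sorted_vals, sorted_idx


def top_k(values, k):
    # Top-k expressed as "sort everything, then slice the first k columns".
    sorted_vals, sorted_idx = sort_with_indices(values)
    return [r[:k] for r in sorted_vals], [r[:k] for r in sorted_idx]


vals, idx = top_k([[1.0, 3.0, 2.0, 3.0]], 2)
print(vals, idx)
```

A pattern match would recognize this whole shape and replace it with a dedicated top-k op, so the full-row sort never has to run.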

The iree_linalg_ext.sort operation implements the TilingInterface. It’s a small step from there to the PartialReductionOpInterface. That will allow handling of sort similar to split-k… both the interface and implementation/use in IREE need work…
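As a rough illustration of the split-k analogy (not IREE code, and not a description of how PartialReductionOpInterface is implemented), the idea can be sketched in Python: sort independent chunks of a row, which could run in parallel, then merge the sorted runs. The function name `chunked_sort_descending` and the `num_chunks` parameter are hypothetical.

```python
import heapq


def chunked_sort_descending(keys, values, num_chunks):
    # Stable descending sort of (key, value) pairs in two phases:
    #   1. partial step: sort each chunk independently (parallelizable),
    #   2. merge step: combine the sorted runs into one sorted sequence.
    n = len(keys)
    if n == 0:
        return [], []
    chunk = -(-n // num_chunks)  # ceiling division; assumes num_chunks >= 1
    runs = []
    for start in range(0, n, chunk):
        part = list(zip(keys[start:start + chunk], values[start:start + chunk]))
        part.sort(key=lambda p: -p[0])  # stable sort within the chunk
        runs.append(part)
    # heapq.merge breaks ties in favor of earlier runs, so the result
    # stays stable across chunk boundaries.
    merged = list(heapq.merge(*runs, key=lambda p: -p[0]))
    return [k for k, _ in merged], [v for _, v in merged]


keys = [1.0, 3.0, 2.0, 3.0, 0.5]
values = [0, 1, 2, 3, 4]
print(chunked_sort_descending(keys, values, num_chunks=2))
```

The result should match a single stable sort of the whole row; only the merge step needs to see all of the runs.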