iree: Unsupported mhlo.sort op lowering

%443:2 = "mhlo.sort"(%441, %442) ( {
^bb0(%arg17: tensor<i32>, %arg18: tensor<i32>, %arg19: tensor<i32>, %arg20: tensor<i32>):  // no predecessors
  %780 = "mhlo.compare"(%arg17, %arg18) {comparison_direction = "GT"} : (tensor<i32>, tensor<i32>) -> tensor<i1>
  "mhlo.return"(%780) : (tensor<i1>) -> ()
}) {dimension = 0 : i64, is_stable = false} : (tensor<3xi32>, tensor<3xi32>) -> (tensor<3xi32>, tensor<3xi32>)
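For reference, here is a small Python sketch of what this op computes. It assumes mhlo.sort follows XLA's Sort semantics:

- The comparator region receives one (lhs, rhs) pair per operand: %arg17/%arg18 for the first operand and %arg19/%arg20 for the second.
- It returns true when the lhs element should be ordered before the rhs element.
- All operands are permuted together along `dimension`.

`sort_1d` and `gt_comparator` are hypothetical helpers that handle 1-D inputs only.

```python
from functools import cmp_to_key

def gt_comparator(lhs0, rhs0, lhs1, rhs1):
    # Mirrors the region above: only the first operand's pair (%arg17, %arg18)
    # is compared, with comparison_direction = "GT"; lhs1/rhs1 are unused.
    return lhs0 > rhs0

def sort_1d(operands, comparator):
    """Sort equal-length 1-D lists together, ordered by `comparator`."""
    n = len(operands[0])

    def pairs(i, j):
        # Interleave (lhs, rhs) per operand, matching the region's block arguments.
        args = []
        for operand in operands:
            args += [operand[i], operand[j]]
        return args

    def cmp(i, j):
        if comparator(*pairs(i, j)):
            return -1  # element i is ordered first
        if comparator(*pairs(j, i)):
            return 1
        return 0

    order = sorted(range(n), key=cmp_to_key(cmp))
    return [[operand[k] for k in order] for operand in operands]

keys, values = sort_1d([[1, 3, 2], [10, 30, 20]], gt_comparator)
print(keys, values)  # [3, 2, 1] [30, 20, 10]
```

With "GT" the result is descending by the first operand, and the second operand just follows along. Python's `sorted` happens to be stable. `is_stable = false` only means the op makes no promise about the order of ties.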

About this issue

  • State: closed
  • Created 3 years ago
  • Comments: 59 (55 by maintainers)

Most upvoted comments

This is the IR that the pass is spitting out, which fails verification:

  builtin.func @sort(%arg0: tensor<1x134832xf32>, %arg1: tensor<1x134832xi32>) -> (tensor<1x134832xf32>, tensor<1x134832xi32>) {
    %0 = mhlo.constant dense<2147483647> : tensor<ui32>
    %1 = mhlo.constant dense<0> : tensor<i32>
    %2:2 = linalg_ext.sort dimension(1) outs(%arg0, %arg1 : tensor<1x134832xf32>, tensor<1x134832xi32>)  {
    ^bb0(%arg2: f32, %arg3: f32, %arg4: i32, %arg5: i32):  // no predecessors
      %3 = bitcast %arg2 : f32 to i32
      %4 = tensor.extract %1[] : tensor<i32>
      %5 = cmpi slt, %3, %4 : i32
      %6 = bitcast %arg2 : f32 to ui32
      %7 = tensor.extract %0[] : tensor<ui32>
      %8 = subi %7, %6 : ui32
      %9 = bitcast %8 : ui32 to i32
      %10 = "mhlo.select"(%5, %9, %3) : (i1, i32, i32) -> tensor<i32>
      %11 = bitcast %arg3 : f32 to i32
      %12 = tensor.extract %1[] : tensor<i32>
      %13 = cmpi slt, %11, %12 : i32
      %14 = bitcast %arg3 : f32 to ui32
      %15 = tensor.extract %0[] : tensor<ui32>
      %16 = subi %15, %14 : ui32
      %17 = bitcast %16 : ui32 to i32
      %18 = "mhlo.select"(%13, %17, %11) : (i1, i32, i32) -> tensor<i32>
      %19 = tensor.extract %10[] : tensor<i32>
      %20 = tensor.extract %18[] : tensor<i32>
      %21 = cmpi slt, %19, %20 : i32
      linalg_ext.yield %21 : i1
    } -> tensor<1x134832xf32>, tensor<1x134832xi32>
    return %2#0, %2#1 : tensor<1x134832xf32>, tensor<1x134832xi32>
  }
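The comparator body above orders f32 keys by comparing them as integers. It works in three steps:

1. Bitcast the float to i32.
2. If the result is negative (the sign bit is set), replace it with 2147483647 minus the bits. This subtraction is done as ui32 arithmetic and bitcast back, so it cannot overflow.
3. Compare the results with a signed less-than.

Below is a minimal Python sketch of that arithmetic. The helper names are hypothetical, and the sketch only mirrors the math. It does not model why this IR fails verification.

```python
import struct

INT32_MAX = 0x7FFFFFFF  # the dense<2147483647> constant (%0) in the IR above

def f32_to_i32_bits(x):
    """Reinterpret an f32 as a signed i32 (like `bitcast %x : f32 to i32`)."""
    return struct.unpack("<i", struct.pack("<f", x))[0]

def wrap_i32(u):
    """Reinterpret the low 32 bits of an integer as a signed i32."""
    u &= 0xFFFFFFFF
    return u - (1 << 32) if u >= (1 << 31) else u

def sortable_key(x):
    bits = f32_to_i32_bits(x)
    if bits < 0:  # sign bit set: negative floats, including -0.0
        # INT32_MAX - bits, computed on the unsigned view, then bitcast back.
        return wrap_i32(INT32_MAX - (bits & 0xFFFFFFFF))
    return bits

def less_than(a, b):
    # Mirrors `cmpi slt, %19, %20` on the transformed keys.
    return sortable_key(a) < sortable_key(b)

floats = [1.0, -2.0, 0.0, -0.0, float("inf"), -1.0]
ids = [0, 1, 2, 3, 4, 5]  # second operand, permuted along with the keys
order = sorted(range(len(floats)), key=lambda i: sortable_key(floats[i]))
print([ids[i] for i in order])  # [1, 5, 3, 2, 0, 4]
```

For a negative float this maps the bits to `-1 - (bits & 0x7FFFFFFF)`, so more negative values get smaller keys. It also orders -0.0 just before +0.0.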

@MaheshRavishankar, preorder traversal is the documented behavior of dialect conversion: https://mlir.llvm.org/docs/DialectConversion/#modes-of-conversion
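Here is a toy illustration of what preorder means for this IR. It uses a hypothetical `Op` class rather than MLIR's real data structures. The enclosing mhlo.sort is visited before the ops inside its comparator region. Roughly, that means a conversion pattern on mhlo.sort sees its body before the body itself has been converted.

```python
from dataclasses import dataclass, field
from typing import List

@dataclass
class Op:
    name: str
    # Each region is modeled as a flat list of ops (one block, for brevity).
    regions: List[List["Op"]] = field(default_factory=list)

def preorder(op):
    """Yield the op's name first, then the nested ops in program order."""
    yield op.name
    for region in op.regions:
        for nested in region:
            yield from preorder(nested)

sort_op = Op("mhlo.sort", [[Op("mhlo.compare"), Op("mhlo.return")]])
func = Op("builtin.func", [[Op("mhlo.constant"), sort_op, Op("std.return")]])
visit_order = list(preorder(func))
print(visit_order)
```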

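mhlo.sort is not IsolatedFromAbove, so its comparator region can implicitly capture values defined outside it, like the 0-rank tensor constants %0 and %1 above. Now I see what's going on. All of this badness comes from conflating scalars with 0-rank tensors!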

The solution is, while lowering the region, to use getUsedValuesDefinedAbove to collect all the implicitly captured values. For now, just check that these are all constants and lower them to scalar constants within the body of the linalg_ext.sort op.
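Below is a minimal sketch of that fix on a toy IR model. Everything in it is hypothetical:

- The `Value`/`Op`/`Region` classes are not MLIR's C++ or Python API.
- `used_values_defined_above` stands in for MLIR's C++ `getUsedValuesDefinedAbove`. It is simplified to a single block without nested regions.
- The `arith.constant` name for the scalar clone is illustrative.

The sketch checks that every implicitly captured value is a 0-rank mhlo.constant and raises otherwise, matching the "for now" restriction. It then clones each one as a scalar constant at the top of the body and rewrites the uses.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Optional

# Toy IR model (hypothetical). eq=False keeps identity-based equality/hashing.
@dataclass(eq=False)
class Value:
    name: str
    type: str
    owner: Optional["Op"] = None  # defining op; None for block arguments

@dataclass(eq=False)
class Op:
    name: str
    operands: List[Value]
    attrs: Dict[str, object] = field(default_factory=dict)
    result: Optional[Value] = None

@dataclass(eq=False)
class Region:
    args: List[Value]
    ops: List[Op]

def make_op(name, operands, result_name=None, result_type=None, attrs=None):
    op = Op(name, list(operands), dict(attrs or {}))
    if result_name is not None:
        op.result = Value(result_name, result_type, op)
    return op

def used_values_defined_above(region):
    """Values used in the region but defined outside it (one block, no nesting)."""
    defined = set(region.args)
    captured = []
    for op in region.ops:
        for v in op.operands:
            if v not in defined and v not in captured:
                captured.append(v)
        if op.result is not None:
            defined.add(op.result)
    return captured

def rank0_element_type(t):
    """'tensor<i32>' -> 'i32'; None for ranked shapes like 'tensor<3xi32>'."""
    if t.startswith("tensor<") and t.endswith(">"):
        inner = t[len("tensor<"):-1]
        if "x" not in inner:
            return inner
    return None

def scalarize_captured_constants(region):
    """Clone each captured 0-rank mhlo.constant as a scalar constant in the body."""
    clones, mapping = [], {}
    for v in used_values_defined_above(region):
        elem = rank0_element_type(v.type)
        if v.owner is None or v.owner.name != "mhlo.constant" or elem is None:
            raise ValueError(f"unsupported implicit capture: {v.name}")
        clone = make_op("arith.constant", [], v.name + "_scalar", elem,
                        {"value": v.owner.attrs["value"]})
        clones.append(clone)
        mapping[v] = clone.result
    for op in region.ops:
        op.operands = [mapping.get(v, v) for v in op.operands]
    region.ops = clones + region.ops

# Example: a comparator body that captures %1 = mhlo.constant dense<0> : tensor<i32>.
c1 = make_op("mhlo.constant", [], "%1", "tensor<i32>", {"value": 0})
lhs, rhs = Value("%arg2", "tensor<i32>"), Value("%arg3", "tensor<i32>")
cmp = make_op("mhlo.compare", [lhs, c1.result], "%5", "tensor<i1>",
              {"comparison_direction": "LT"})
body = Region([lhs, rhs], [cmp, make_op("mhlo.return", [cmp.result])])

captured = used_values_defined_above(body)
scalarize_captured_constants(body)
```

After the rewrite, the body no longer refers to anything outside the region. Lowering it only sees scalars, so it needs no tensor.extract of outer 0-rank tensors.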

mhlo.sort has really broken op semantics here. It's a valid op only because MHLO decided it is a valid operation. This should not have been allowed in the first place.

I will take a look after landing https://github.com/google/iree/pull/6562. It changes the pipeline.

I hate it when I create really expensive identity ops 😃