iree: [llvm-cpu] problem lowering Gelu to llvm

What happened?

small repro IR:

// Minimal reproducer: a single torch.aten.gelu applied to a statically shaped
// 2x4096x9728 f32 tensor.
module @proc_args {
  // Public entry point. Converts the builtin tensor argument to a torch
  // vtensor, calls @compute, and converts the result back to a builtin tensor.
  // The torch.args_schema / torch.return_schema attributes are carried along
  // as opaque strings.
  func.func @dynamic_dim(%arg0: tensor<2x4096x9728xf32>) -> tensor<2x4096x9728xf32> attributes {torch.args_schema = "[1, {\22type\22: \22builtins.tuple\22, \22context\22: \22null\22, \22children_spec\22: [{\22type\22: \22builtins.list\22, \22context\22: \22null\22, \22children_spec\22: [{\22type\22: null, \22context\22: null, \22children_spec\22: []}]}, {\22type\22: \22builtins.dict\22, \22context\22: \22[]\22, \22children_spec\22: []}]}]", torch.return_schema = "[1, {\22type\22: null, \22context\22: null, \22children_spec\22: []}]"} {
    %0 = torch_c.from_builtin_tensor %arg0 : tensor<2x4096x9728xf32> -> !torch.vtensor<[2,4096,9728],f32>
    %1 = call @compute(%0) : (!torch.vtensor<[2,4096,9728],f32>) -> !torch.vtensor<[2,4096,9728],f32>
    %2 = torch_c.to_builtin_tensor %1 : !torch.vtensor<[2,4096,9728],f32> -> tensor<2x4096x9728xf32>
    return %2 : tensor<2x4096x9728xf32>
  }
  // Private helper holding the only compute op in the reproducer.
  func.func private @compute(%arg0: !torch.vtensor<[2,4096,9728],f32>) -> !torch.vtensor<[2,4096,9728],f32> {
    // The "tanh" string is the second operand of torch.aten.gelu
    // (presumably selecting the tanh-based approximation -- confirm against
    // the torch-mlir op definition).
    %str = torch.constant.str "tanh"
    // This is the op quoted alongside the reported compiler error.
    %0 = torch.aten.gelu %arg0, %str : !torch.vtensor<[2,4096,9728],f32>, !torch.str -> !torch.vtensor<[2,4096,9728],f32>
    return %0 : !torch.vtensor<[2,4096,9728],f32>
  }
}

IR prior to error:

// -----// IR Dump After CSE (cse) //----- //
// Dispatch function as dumped after CSE. Each loop iteration loads 8 f32
// values from binding 0, evaluates the expanded gelu arithmetic on them, and
// stores 8 results to binding 1.
module {
  func.func @dynamic_dim$async_dispatch_0_generic_79691776_f32() {
    // %cst .. %cst_9 are the splat coefficients consumed by the two math.fma
    // chains below (presumably the polynomial coefficients of an expanded tanh
    // approximation -- confirm against the pass that produced them).
    %cst = arith.constant dense<1.19825836E-6> : vector<8xf32> loc(unknown)
    %cst_0 = arith.constant dense<1.18534706E-4> : vector<8xf32> loc(unknown)
    %cst_1 = arith.constant dense<0.00226843474> : vector<8xf32> loc(unknown)
    %cst_2 = arith.constant dense<0.00489352504> : vector<8xf32> loc(unknown)
    %cst_3 = arith.constant dense<-2.76076837E-16> : vector<8xf32> loc(unknown)
    %cst_4 = arith.constant dense<2.00018794E-13> : vector<8xf32> loc(unknown)
    %cst_5 = arith.constant dense<-8.60467184E-11> : vector<8xf32> loc(unknown)
    %cst_6 = arith.constant dense<5.12229725E-8> : vector<8xf32> loc(unknown)
    %cst_7 = arith.constant dense<1.48572235E-5> : vector<8xf32> loc(unknown)
    %cst_8 = arith.constant dense<6.37261954E-4> : vector<8xf32> loc(unknown)
    %cst_9 = arith.constant dense<0.00489352457> : vector<8xf32> loc(unknown)
    // %cst_10 is the small-magnitude threshold compared against |%9| below;
    // %cst_11 / %cst_12 are the upper / lower bounds used by the two selects.
    %cst_10 = arith.constant dense<4.000000e-04> : vector<8xf32> loc(unknown)
    %cst_11 = arith.constant dense<7.99881172> : vector<8xf32> loc(unknown)
    %cst_12 = arith.constant dense<-7.99881172> : vector<8xf32> loc(unknown)
    // Scalars of the gelu formula itself: 0.5, 1.0, 0.797724 and 0.044715.
    %cst_13 = arith.constant dense<5.000000e-01> : vector<8xf32> loc(unknown)
    %cst_14 = arith.constant dense<1.000000e+00> : vector<8xf32> loc(unknown)
    %cst_15 = arith.constant dense<7.977240e-01> : vector<8xf32> loc(unknown)
    %cst_16 = arith.constant dense<4.471500e-02> : vector<8xf32> loc(unknown)
    // Integer exponent 3 as a vector<8xi64> splat. It becomes the second
    // operand of math.fpowi below, and vector<8xi64> is the operand type the
    // reported 'llvm.intr.powi' verifier error complains about.
    %cst_17 = arith.constant dense<3> : vector<8xi64> loc(unknown)
    %c8 = arith.constant 8 : index loc(unknown)
    %c4096 = arith.constant 4096 : index loc(unknown)
    %c0 = arith.constant 0 : index loc(unknown)
    // Binding 0 is flagged ReadOnly and is only loaded from; binding 1 is the
    // target of the vector.store. Both are flat memref<79691776xf32> buffers
    // with an assumed 64-byte alignment.
    %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<79691776xf32> loc(callsite("repro.mlir":10:10 at "repro.mlir":4:10))
    memref.assume_alignment %0, 64 : memref<79691776xf32> loc(callsite("repro.mlir":10:10 at "repro.mlir":4:10))
    %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : memref<79691776xf32> loc(callsite("repro.mlir":10:10 at "repro.mlir":4:10))
    memref.assume_alignment %1, 64 : memref<79691776xf32> loc(callsite("repro.mlir":10:10 at "repro.mlir":4:10))
    %workgroup_id_x = hal.interface.workgroup.id[0] : index loc(callsite("repro.mlir":10:10 at "repro.mlir":4:10))
    // Loop over %2 = 0, 8, 16, ... while %2 < 4096 (header ^bb1, body ^bb2,
    // exit ^bb3).
    cf.br ^bb1(%c0 : index) loc(callsite("repro.mlir":10:10 at "repro.mlir":4:10))
  ^bb1(%2: index loc(callsite("repro.mlir":10:10 at "repro.mlir":4:10))):  // 2 preds: ^bb0, ^bb2
    %3 = arith.cmpi slt, %2, %c4096 : index loc(callsite("repro.mlir":10:10 at "repro.mlir":4:10))
    cf.cond_br %3, ^bb2, ^bb3 loc(callsite("repro.mlir":10:10 at "repro.mlir":4:10))
  ^bb2:  // pred: ^bb1
    // Flat element offset: loop index + workgroup_id_x * 4096.
    %4 = affine.apply affine_map<()[s0, s1] -> (s0 + s1 * 4096)>()[%2, %workgroup_id_x] loc(callsite("repro.mlir":10:10 at "repro.mlir":4:10))
    %5 = vector.load %0[%4] : memref<79691776xf32>, vector<8xf32> loc(callsite("repro.mlir":10:10 at "repro.mlir":4:10))
    // x^3 via math.fpowi with the vector<8xi64> exponent. This is the op whose
    // lowering to llvm.intr.powi is rejected in the reported error.
    %6 = math.fpowi %5, %cst_17 : vector<8xf32>, vector<8xi64> loc(callsite("repro.mlir":10:10 at "repro.mlir":4:10))
    // %9 = (x + 0.044715 * x^3) * 0.797724
    %7 = arith.mulf %6, %cst_16 : vector<8xf32> loc(callsite("repro.mlir":10:10 at "repro.mlir":4:10))
    %8 = arith.addf %5, %7 : vector<8xf32> loc(callsite("repro.mlir":10:10 at "repro.mlir":4:10))
    %9 = arith.mulf %8, %cst_15 : vector<8xf32> loc(callsite("repro.mlir":10:10 at "repro.mlir":4:10))
    // %13 = %9 limited to the range [%cst_12, %cst_11] by two compare+select pairs.
    %10 = arith.cmpf ult, %9, %cst_11 : vector<8xf32> loc(callsite("repro.mlir":10:10 at "repro.mlir":4:10))
    %11 = arith.select %10, %9, %cst_11 : vector<8xi1>, vector<8xf32> loc(callsite("repro.mlir":10:10 at "repro.mlir":4:10))
    %12 = arith.cmpf ugt, %11, %cst_12 : vector<8xf32> loc(callsite("repro.mlir":10:10 at "repro.mlir":4:10))
    %13 = arith.select %12, %11, %cst_12 : vector<8xi1>, vector<8xf32> loc(callsite("repro.mlir":10:10 at "repro.mlir":4:10))
    // %15 is the per-lane mask |%9| < 4e-4, used by the final select of this
    // section to pick %13 itself instead of the rational value.
    %14 = math.absf %9 : vector<8xf32> loc(callsite("repro.mlir":10:10 at "repro.mlir":4:10))
    %15 = arith.cmpf olt, %14, %cst_10 : vector<8xf32> loc(callsite("repro.mlir":10:10 at "repro.mlir":4:10))
    // %16 = %13 squared. %23 is %13 times a nested-fma polynomial in %16
    // (numerator); %26 is a second nested-fma polynomial in %16 (denominator).
    %16 = arith.mulf %13, %13 : vector<8xf32> loc(callsite("repro.mlir":10:10 at "repro.mlir":4:10))
    %17 = math.fma %16, %cst_3, %cst_4 : vector<8xf32> loc(callsite("repro.mlir":10:10 at "repro.mlir":4:10))
    %18 = math.fma %16, %17, %cst_5 : vector<8xf32> loc(callsite("repro.mlir":10:10 at "repro.mlir":4:10))
    %19 = math.fma %16, %18, %cst_6 : vector<8xf32> loc(callsite("repro.mlir":10:10 at "repro.mlir":4:10))
    %20 = math.fma %16, %19, %cst_7 : vector<8xf32> loc(callsite("repro.mlir":10:10 at "repro.mlir":4:10))
    %21 = math.fma %16, %20, %cst_8 : vector<8xf32> loc(callsite("repro.mlir":10:10 at "repro.mlir":4:10))
    %22 = math.fma %16, %21, %cst_9 : vector<8xf32> loc(callsite("repro.mlir":10:10 at "repro.mlir":4:10))
    %23 = arith.mulf %13, %22 : vector<8xf32> loc(callsite("repro.mlir":10:10 at "repro.mlir":4:10))
    %24 = math.fma %16, %cst, %cst_0 : vector<8xf32> loc(callsite("repro.mlir":10:10 at "repro.mlir":4:10))
    %25 = math.fma %16, %24, %cst_1 : vector<8xf32> loc(callsite("repro.mlir":10:10 at "repro.mlir":4:10))
    %26 = math.fma %16, %25, %cst_2 : vector<8xf32> loc(callsite("repro.mlir":10:10 at "repro.mlir":4:10))
    %27 = arith.divf %23, %26 : vector<8xf32> loc(callsite("repro.mlir":10:10 at "repro.mlir":4:10))
    %28 = arith.select %15, %13, %27 : vector<8xi1>, vector<8xf32> loc(callsite("repro.mlir":10:10 at "repro.mlir":4:10))
    // result = x * (0.5 * (1 + %28))
    %29 = arith.addf %28, %cst_14 : vector<8xf32> loc(callsite("repro.mlir":10:10 at "repro.mlir":4:10))
    %30 = arith.mulf %29, %cst_13 : vector<8xf32> loc(callsite("repro.mlir":10:10 at "repro.mlir":4:10))
    %31 = arith.mulf %5, %30 : vector<8xf32> loc(callsite("repro.mlir":10:10 at "repro.mlir":4:10))
    vector.store %31, %1[%4] : memref<79691776xf32>, vector<8xf32> loc(callsite("repro.mlir":10:10 at "repro.mlir":4:10))
    // Advance the loop index by the vector width (8) and branch back to the header.
    %32 = arith.addi %2, %c8 : index loc(callsite("repro.mlir":10:10 at "repro.mlir":4:10))
    cf.br ^bb1(%32 : index) loc(callsite("repro.mlir":10:10 at "repro.mlir":4:10))
  ^bb3:  // pred: ^bb1
    return loc(callsite("repro.mlir":10:10 at "repro.mlir":4:10))
  } loc(callsite("repro.mlir":10:10 at "repro.mlir":4:10))
} loc(callsite("repro.mlir":10:10 at "repro.mlir":4:10))

the error:

repro.mlir:10:10: error: 'llvm.intr.powi' op operand #1 must be signless integer, but got 'vector<8xi64>'
    %0 = torch.aten.gelu %arg0, %str : !torch.vtensor<[2,4096,9728],f32>, !torch.str -> !torch.vtensor<[2,4096,9728],f32>

Is this just because it's using a signed vector of ints? Do I just go through the passes until I see what produced that vector?

EDIT: looks like math.fpowi expects signed ints, so is this just an incorrect lowering somewhere?

About this issue

  • Original URL
  • State: closed
  • Created 3 months ago
  • Comments: 27 (25 by maintainers)

Commits related to this issue

Most upvoted comments

For this kind of operation, we want to lower it with a polynomial approximation. Otherwise, it could lead to a libc call. So I think the proper solution is adding the approximation conversion, like what we’ve done for other power operations: https://github.com/llvm/llvm-project/blob/dcd0f2b6103072b74b446c2d1e9ecec60001a28c/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp#L205-L232

nice, that works, thanks. I’ll push a patch that bumps llvm and adds this

error: 'math.fpowi' op requires the same shape for all operands and results.

https://paperswithcode.com/method/gelu

explains the op, and the approximation used. (kind of funny, computing is a lot of approximations to approximate approximations)

Okay, I think this explains why we have constant=3.

@bjacob thanks for the great input. After thinking a while, I think the expansion we want here is similar to what you’ve described. We want to rewrite fpowi into a few arith.mulf ops. The logic should be the same for a constant integer exponent. We can embed the loop in C++ when the exponent is a compile-time constant; we can create an scf.for op when it is not a compile-time constant. The patterns will be put in ExpandPatterns.cpp in both cases. We want an iterative divide-and-conquer (O(log N)) expansion for fpowi ops.

https://paperswithcode.com/method/gelu

explains the op, and the approximation used. (kind of funny, computing is a lot of approximations to approximate approximations)