iree: [llvm-cpu] problem lowering Gelu to llvm
What happened?
small repro IR:
// Reproducer: a single torch.aten.gelu using the "tanh" approximation mode
// applied to a 2x4096x9728 f32 tensor.
// NOTE: despite the function name, every dimension below is static.
module @proc_args {
// Public entry point: wraps the builtin tensor as a torch value tensor,
// calls @compute, and converts the result back to a builtin tensor.
func.func @dynamic_dim(%arg0: tensor<2x4096x9728xf32>) -> tensor<2x4096x9728xf32> attributes {torch.args_schema = "[1, {\22type\22: \22builtins.tuple\22, \22context\22: \22null\22, \22children_spec\22: [{\22type\22: \22builtins.list\22, \22context\22: \22null\22, \22children_spec\22: [{\22type\22: null, \22context\22: null, \22children_spec\22: []}]}, {\22type\22: \22builtins.dict\22, \22context\22: \22[]\22, \22children_spec\22: []}]}]", torch.return_schema = "[1, {\22type\22: null, \22context\22: null, \22children_spec\22: []}]"} {
%0 = torch_c.from_builtin_tensor %arg0 : tensor<2x4096x9728xf32> -> !torch.vtensor<[2,4096,9728],f32>
%1 = call @compute(%0) : (!torch.vtensor<[2,4096,9728],f32>) -> !torch.vtensor<[2,4096,9728],f32>
%2 = torch_c.to_builtin_tensor %1 : !torch.vtensor<[2,4096,9728],f32> -> tensor<2x4096x9728xf32>
return %2 : tensor<2x4096x9728xf32>
}
// Applies GELU with approximate = "tanh". The reported lowering error is
// attributed to this op's source location.
func.func private @compute(%arg0: !torch.vtensor<[2,4096,9728],f32>) -> !torch.vtensor<[2,4096,9728],f32> {
%str = torch.constant.str "tanh"
%0 = torch.aten.gelu %arg0, %str : !torch.vtensor<[2,4096,9728],f32>, !torch.str -> !torch.vtensor<[2,4096,9728],f32>
return %0 : !torch.vtensor<[2,4096,9728],f32>
}
}
IR prior to error:
// -----// IR Dump After CSE (cse) //----- //
module {
// Element-wise tanh-approximated GELU over a flat buffer of 79691776
// (= 2*4096*9728) f32 values. Each workgroup handles a contiguous
// 4096-element slice, processed 8 lanes at a time.
func.func @dynamic_dim$async_dispatch_0_generic_79691776_f32() {
// %cst .. %cst_9: coefficients of the rational polynomial below (%16..%27).
// Presumably this is the expanded form of math.tanh.
%cst = arith.constant dense<1.19825836E-6> : vector<8xf32> loc(unknown)
%cst_0 = arith.constant dense<1.18534706E-4> : vector<8xf32> loc(unknown)
%cst_1 = arith.constant dense<0.00226843474> : vector<8xf32> loc(unknown)
%cst_2 = arith.constant dense<0.00489352504> : vector<8xf32> loc(unknown)
%cst_3 = arith.constant dense<-2.76076837E-16> : vector<8xf32> loc(unknown)
%cst_4 = arith.constant dense<2.00018794E-13> : vector<8xf32> loc(unknown)
%cst_5 = arith.constant dense<-8.60467184E-11> : vector<8xf32> loc(unknown)
%cst_6 = arith.constant dense<5.12229725E-8> : vector<8xf32> loc(unknown)
%cst_7 = arith.constant dense<1.48572235E-5> : vector<8xf32> loc(unknown)
%cst_8 = arith.constant dense<6.37261954E-4> : vector<8xf32> loc(unknown)
%cst_9 = arith.constant dense<0.00489352457> : vector<8xf32> loc(unknown)
// Below this magnitude the polynomial is bypassed (see %15 / %28).
%cst_10 = arith.constant dense<4.000000e-04> : vector<8xf32> loc(unknown)
// Clamp bounds for the tanh argument (see %10..%13).
%cst_11 = arith.constant dense<7.99881172> : vector<8xf32> loc(unknown)
%cst_12 = arith.constant dense<-7.99881172> : vector<8xf32> loc(unknown)
%cst_13 = arith.constant dense<5.000000e-01> : vector<8xf32> loc(unknown)
%cst_14 = arith.constant dense<1.000000e+00> : vector<8xf32> loc(unknown)
// Scale applied to (x + 0.044715*x^3) before tanh.
// NOTE(review): sqrt(2/pi) ~= 0.7978846, which this value (0.797724) does not
// match to four digits. Confirm against the frontend's GELU lowering.
%cst_15 = arith.constant dense<7.977240e-01> : vector<8xf32> loc(unknown)
%cst_16 = arith.constant dense<4.471500e-02> : vector<8xf32> loc(unknown)
// Exponent 3 for the x^3 term, splatted to a vector<8xi64>. This vector
// exponent is what llvm.intr.powi rejects ("must be signless integer").
%cst_17 = arith.constant dense<3> : vector<8xi64> loc(unknown)
%c8 = arith.constant 8 : index loc(unknown)
%c4096 = arith.constant 4096 : index loc(unknown)
%c0 = arith.constant 0 : index loc(unknown)
// %0: read-only input buffer; %1: output buffer. Both are flat 79691776 x f32.
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<79691776xf32> loc(callsite("repro.mlir":10:10 at "repro.mlir":4:10))
memref.assume_alignment %0, 64 : memref<79691776xf32> loc(callsite("repro.mlir":10:10 at "repro.mlir":4:10))
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : memref<79691776xf32> loc(callsite("repro.mlir":10:10 at "repro.mlir":4:10))
memref.assume_alignment %1, 64 : memref<79691776xf32> loc(callsite("repro.mlir":10:10 at "repro.mlir":4:10))
%workgroup_id_x = hal.interface.workgroup.id[0] : index loc(callsite("repro.mlir":10:10 at "repro.mlir":4:10))
// Loop %2 = 0, 8, 16, ... while %2 < 4096 (offset within this workgroup's slice).
cf.br ^bb1(%c0 : index) loc(callsite("repro.mlir":10:10 at "repro.mlir":4:10))
^bb1(%2: index loc(callsite("repro.mlir":10:10 at "repro.mlir":4:10))): // 2 preds: ^bb0, ^bb2
%3 = arith.cmpi slt, %2, %c4096 : index loc(callsite("repro.mlir":10:10 at "repro.mlir":4:10))
cf.cond_br %3, ^bb2, ^bb3 loc(callsite("repro.mlir":10:10 at "repro.mlir":4:10))
^bb2: // pred: ^bb1
// Global element index = offset-in-slice + workgroup_id * 4096.
%4 = affine.apply affine_map<()[s0, s1] -> (s0 + s1 * 4096)>()[%2, %workgroup_id_x] loc(callsite("repro.mlir":10:10 at "repro.mlir":4:10))
%5 = vector.load %0[%4] : memref<79691776xf32>, vector<8xf32> loc(callsite("repro.mlir":10:10 at "repro.mlir":4:10))
// x^3 via math.fpowi with a vector<8xi64> exponent: the op whose lowering
// to llvm.intr.powi fails.
%6 = math.fpowi %5, %cst_17 : vector<8xf32>, vector<8xi64> loc(callsite("repro.mlir":10:10 at "repro.mlir":4:10))
// u = 0.797724 * (x + 0.044715 * x^3)
%7 = arith.mulf %6, %cst_16 : vector<8xf32> loc(callsite("repro.mlir":10:10 at "repro.mlir":4:10))
%8 = arith.addf %5, %7 : vector<8xf32> loc(callsite("repro.mlir":10:10 at "repro.mlir":4:10))
%9 = arith.mulf %8, %cst_15 : vector<8xf32> loc(callsite("repro.mlir":10:10 at "repro.mlir":4:10))
// Clamp u to [-7.99881172, 7.99881172]. The unordered predicates (ult/ugt)
// select the input for NaN lanes, so NaN propagates.
%10 = arith.cmpf ult, %9, %cst_11 : vector<8xf32> loc(callsite("repro.mlir":10:10 at "repro.mlir":4:10))
%11 = arith.select %10, %9, %cst_11 : vector<8xi1>, vector<8xf32> loc(callsite("repro.mlir":10:10 at "repro.mlir":4:10))
%12 = arith.cmpf ugt, %11, %cst_12 : vector<8xf32> loc(callsite("repro.mlir":10:10 at "repro.mlir":4:10))
%13 = arith.select %12, %11, %cst_12 : vector<8xi1>, vector<8xf32> loc(callsite("repro.mlir":10:10 at "repro.mlir":4:10))
// Lanes with |u| < 4.0e-4 take the clamped value directly (tanh(u) ~= u).
%14 = math.absf %9 : vector<8xf32> loc(callsite("repro.mlir":10:10 at "repro.mlir":4:10))
%15 = arith.cmpf olt, %14, %cst_10 : vector<8xf32> loc(callsite("repro.mlir":10:10 at "repro.mlir":4:10))
// Numerator: c * P(c^2), with c the clamped value, evaluated in Horner form
// using math.fma.
%16 = arith.mulf %13, %13 : vector<8xf32> loc(callsite("repro.mlir":10:10 at "repro.mlir":4:10))
%17 = math.fma %16, %cst_3, %cst_4 : vector<8xf32> loc(callsite("repro.mlir":10:10 at "repro.mlir":4:10))
%18 = math.fma %16, %17, %cst_5 : vector<8xf32> loc(callsite("repro.mlir":10:10 at "repro.mlir":4:10))
%19 = math.fma %16, %18, %cst_6 : vector<8xf32> loc(callsite("repro.mlir":10:10 at "repro.mlir":4:10))
%20 = math.fma %16, %19, %cst_7 : vector<8xf32> loc(callsite("repro.mlir":10:10 at "repro.mlir":4:10))
%21 = math.fma %16, %20, %cst_8 : vector<8xf32> loc(callsite("repro.mlir":10:10 at "repro.mlir":4:10))
%22 = math.fma %16, %21, %cst_9 : vector<8xf32> loc(callsite("repro.mlir":10:10 at "repro.mlir":4:10))
%23 = arith.mulf %13, %22 : vector<8xf32> loc(callsite("repro.mlir":10:10 at "repro.mlir":4:10))
// Denominator: Q(c^2), also in Horner form.
%24 = math.fma %16, %cst, %cst_0 : vector<8xf32> loc(callsite("repro.mlir":10:10 at "repro.mlir":4:10))
%25 = math.fma %16, %24, %cst_1 : vector<8xf32> loc(callsite("repro.mlir":10:10 at "repro.mlir":4:10))
%26 = math.fma %16, %25, %cst_2 : vector<8xf32> loc(callsite("repro.mlir":10:10 at "repro.mlir":4:10))
%27 = arith.divf %23, %26 : vector<8xf32> loc(callsite("repro.mlir":10:10 at "repro.mlir":4:10))
// t = small-input bypass ? c : c * P(c^2) / Q(c^2)
%28 = arith.select %15, %13, %27 : vector<8xi1>, vector<8xf32> loc(callsite("repro.mlir":10:10 at "repro.mlir":4:10))
// result = x * 0.5 * (1 + t)
%29 = arith.addf %28, %cst_14 : vector<8xf32> loc(callsite("repro.mlir":10:10 at "repro.mlir":4:10))
%30 = arith.mulf %29, %cst_13 : vector<8xf32> loc(callsite("repro.mlir":10:10 at "repro.mlir":4:10))
%31 = arith.mulf %5, %30 : vector<8xf32> loc(callsite("repro.mlir":10:10 at "repro.mlir":4:10))
vector.store %31, %1[%4] : memref<79691776xf32>, vector<8xf32> loc(callsite("repro.mlir":10:10 at "repro.mlir":4:10))
%32 = arith.addi %2, %c8 : index loc(callsite("repro.mlir":10:10 at "repro.mlir":4:10))
cf.br ^bb1(%32 : index) loc(callsite("repro.mlir":10:10 at "repro.mlir":4:10))
^bb3: // pred: ^bb1
return loc(callsite("repro.mlir":10:10 at "repro.mlir":4:10))
} loc(callsite("repro.mlir":10:10 at "repro.mlir":4:10))
} loc(callsite("repro.mlir":10:10 at "repro.mlir":4:10))
the error:
repro.mlir:10:10: error: 'llvm.intr.powi' op operand #1 must be signless integer, but got 'vector<8xi64>'
%0 = torch.aten.gelu %arg0, %str : !torch.vtensor<[2,4096,9728],f32>, !torch.str -> !torch.vtensor<[2,4096,9728],f32>
Is this just because it's using a signed vector of ints? Do I just go through the passes until I see what produced that vector?
EDIT: looks like math.fpowi expects signed ints, so is this just an incorrect lowering somewhere?
About this issue
- Original URL
- State: closed
- Created 3 months ago
- Comments: 27 (25 by maintainers)
For this kind of operation, we want to lower it with a polynomial approximation. Otherwise, it could lead to a libc call. So I think the proper solution is to add the approximation conversion, as we’ve done for other power operations: https://github.com/llvm/llvm-project/blob/dcd0f2b6103072b74b446c2d1e9ecec60001a28c/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp#L205-L232
Nice, that works, thanks. I’ll push a patch that bumps LLVM and adds this.
We need to add the patterns to https://github.com/openxla/iree/blob/190d9592d8b561b32b471638750f1c72e1ec562f/compiler/src/iree/compiler/Codegen/Common/PolynomialApproximationPass.cpp#L28-L31
I.e., call
populateExpandFPowIPattern(mathPatterns);
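in the pass.

error: 'math.fpowi' op requires the same shape for all operands and results

Okay, I think this explains why we have constant=3 (the exponent is the splat `dense<3> : vector<8xi64>`, i.e. the x³ term of the tanh-based GELU approximation).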
@bjacob thanks for the great input. After thinking a while, I think the expansion we want here is similar to what you’ve described: we want to rewrite `fpowi` into a few `arith.mulf` ops. The logic should be the same for a constant integer exponent. We can embed the loop in C++ when the exponent is a compile-time constant, and we can create an `scf.for` op when it is not. In both cases, the patterns will be put in `ExpandPatterns.cpp`. We want an iterative divide-and-conquer (O(log N)) expansion for `fpowi` ops, i.e. exponentiation by squaring.

https://paperswithcode.com/method/gelu explains the op and the approximation used. (Kind of funny: computing is a lot of approximations to approximate approximations.)