iree: [CPU] LLaMA F32 execution crashes when data tiling enabled

We currently get a segfault when running LLaMA in F32 with data tiling enabled (ukernels disabled). It only crashes when the input is 1x4096; 1x8 and 1x64 inputs work. All shapes also work without data tiling enabled (the IREE default).

I use the following compilation command:

iree-compile --iree-hal-target-backends=llvm-cpu --iree-llvmcpu-target-cpu=cascadelake \
  --iree-input-type=stablehlo --iree-llvmcpu-target-triple=x86_64-unknown-linux-gnu \
  --iree-input-demote-f64-to-f32=false --iree-input-demote-i64-to-i32=false \
  --iree-llvmcpu-link-embedded=false --iree-opt-data-tiling=true --iree-llvmcpu-enable-microkernels=false \
  llama-linalg-dynamic.mlir -o llama-linalg-dynamic.vmfb
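
The compiled module is then run roughly along these lines (a hypothetical invocation: the exported function name, input dtype, and device flag are my assumptions and depend on the model export and IREE build):

iree-run-module --device=local-task --module=llama-linalg-dynamic.vmfb \
  --function=forward --input=1x4096xi64=0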

The crash seems to happen at the tensor.unpack operation in dispatch forward_dispatch_26_generic_32xDxD_f32:

hal.executable public @forward_dispatch_26 {
  hal.executable.variant public @system_elf_x86_64, target = <"llvm-cpu", "system-elf-x86_64", {cpu = "cascadelake", cpu_features = "+cmov,+mmx,+popcnt,+sse,+sse2,+sse3,+ssse3,+sse4.1,+sse4.2,+avx,+avx2,+fma,+avx512f,+bmi,+bmi2,+aes,+pclmul,+avx512vl,+avx512bw,+avx512dq,+avx512cd,+avx512vnni,+adx,+clflushopt,+clwb,+cx16,+cx8,+crc32,+f16c,+fsgsbase,+fxsr,+invpcid,+lzcnt,+movbe,+pku,+prfchw,+rdrnd,+rdseed,+sahf,+x87,+xsave,+xsavec,+xsaveopt,+xsaves,+evex512", data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128", link_embedded = false, native_vector_size = 64 : index, target_triple = "x86_64-unknown-linux-gnu", ukernels = false}> {
    hal.executable.export public @forward_dispatch_26_generic_32xDxD_f32 ordinal(0) layout(#hal.pipeline.layout<push_constants = 7, sets = [<0, bindings = [<0, storage_buffer, ReadOnly>, <1, storage_buffer, ReadOnly>, <2, storage_buffer>]>]>) {
    ^bb0(%arg0: !hal.device, %arg1: index, %arg2: index):
      %x, %y, %z = flow.dispatch.workgroup_count_from_slice %arg1, %arg2
      hal.return %x, %y, %z : index, index, index
    }
    builtin.module {
      func.func @forward_dispatch_26_generic_32xDxD_f32() {
        %c32_i64 = arith.constant 32 : i64
        %c1 = arith.constant 1 : index
        %cst = arith.constant 0.000000e+00 : f32
        %cst_0 = arith.constant 11.3137083 : f32
        %c0 = arith.constant 0 : index
        %0 = hal.interface.constant.load[0] : i32
        %1 = hal.interface.constant.load[1] : i32
        %2 = hal.interface.constant.load[2] : i32
        %3 = hal.interface.constant.load[3] : i32
        %4 = hal.interface.constant.load[4] : i32
        %5 = hal.interface.constant.load[5] : i32
        %6 = hal.interface.constant.load[6] : i32
        %7 = arith.extui %0 : i32 to i64
        %8 = arith.extui %1 : i32 to i64
        %9 = arith.shli %8, %c32_i64 : i64
        %10 = arith.ori %7, %9 : i64
        %11 = arith.index_castui %10 : i64 to index
        %12 = arith.index_castui %2 : i32 to index
        %13 = arith.extui %3 : i32 to i64
        %14 = arith.extui %4 : i32 to i64
        %15 = arith.shli %14, %c32_i64 : i64
        %16 = arith.ori %13, %15 : i64
        %17 = arith.index_castui %16 : i64 to index
        %18 = arith.extui %5 : i32 to i64
        %19 = arith.extui %6 : i32 to i64
        %20 = arith.shli %19, %c32_i64 : i64
        %21 = arith.ori %18, %20 : i64
        %22 = arith.index_castui %21 : i64 to index
        %23 = flow.dispatch.workload.ordinal %17, 0 : index
        %24 = flow.dispatch.workload.ordinal %22, 1 : index
        %25 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%11) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<32x2x?x16x16xf32>>{%23}
        %26 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<1x1x?x?xf32>>{%24, %24}
        %27 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%12) : !flow.dispatch.tensor<writeonly:tensor<32x?x?xf32>>{%24, %24}
        %28 = flow.dispatch.tensor.load %25, offsets = [0, 0, 0, 0, 0], sizes = [32, 2, %23, 16, 16], strides = [1, 1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<32x2x?x16x16xf32>>{%23} -> tensor<32x2x?x16x16xf32>
        %29 = flow.dispatch.tensor.load %26, offsets = [0, 0, 0, 0], sizes = [1, 1, %24, %24], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<1x1x?x?xf32>>{%24, %24} -> tensor<1x1x?x?xf32>
        %30 = tensor.empty(%24, %24) : tensor<32x?x?xf32>
        %unpack = tensor.unpack %28 inner_dims_pos = [1, 2] inner_tiles = [16, 16] into %30 : tensor<32x2x?x16x16xf32> -> tensor<32x?x?xf32>
        %31 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%unpack : tensor<32x?x?xf32>) outs(%30 : tensor<32x?x?xf32>) {
        ^bb0(%in: f32, %out: f32):
          %32 = linalg.index 1 : index
          %33 = linalg.index 2 : index
          %34 = arith.cmpi eq, %24, %c1 : index
          %35 = arith.select %34, %c0, %32 : index
          %36 = arith.select %34, %c0, %33 : index
          %extracted = tensor.extract %29[%c0, %c0, %35, %36] : tensor<1x1x?x?xf32>
          %37 = arith.addf %extracted, %cst : f32
          %38 = arith.divf %in, %cst_0 : f32
          %39 = arith.addf %38, %37 : f32
          linalg.yield %39 : f32
        } -> tensor<32x?x?xf32>
        flow.dispatch.tensor.store %31, %27, offsets = [0, 0, 0], sizes = [32, %24, %24], strides = [1, 1, 1] : tensor<32x?x?xf32> -> !flow.dispatch.tensor<writeonly:tensor<32x?x?xf32>>{%24, %24}
        return
      }
    }
  }
}
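
For context on where this can go wrong: the tensor.unpack above reverses a 16x16 pack along dims 1 and 2, so at runtime the outer dims of %28 must equal the ceil-division of the corresponding output sizes by the tile sizes, with any padding from the last partial tile dropped. A minimal sketch of that shape contract (plain Python, illustrative only; the helper name is mine, not IREE's):

import math

def packed_outer(size: int, tile: int) -> int:
    # Outer dim produced by tensor.pack: ceil-divide, padding the last tile.
    return math.ceil(size / tile)

# Shapes from the dispatch: unpack tensor<32x2x?x16x16xf32>{%23}
# into tensor<32x?x?xf32>{%24, %24} with inner_tiles = [16, 16].
# Dim 1 has a static outer size of 2, so the runtime size %24 must
# satisfy ceil(%24 / 16) == 2, i.e. 17 <= %24 <= 32.
for d in (17, 32):
    assert packed_outer(d, 16) == 2
# Dim 2 has dynamic outer %23, so runtime consistency requires
#   %23 == ceil(%24 / 16)
# e.g. %24 = 4096 would need %23 = 256; any mismatch makes the unpack
# read or write out of bounds.
assert packed_outer(4096, 16) == 256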

@MaheshRavishankar, let me know if you can spot what's going on in the current IR, or if you need the IR from before dispatch creation/data tiling.

About this issue

  • State: closed
  • Created 9 months ago
  • Comments: 30 (29 by maintainers)

Most upvoted comments

FYI, I will help take a look at it since I implemented most of the unpack support. I will get to this early next week.

Sounds good @hanhanW! I don’t expect to have time for this before Wednesday, so you’ll beat me to it 😃.

Let me know if you need help.
