iree: Failed to legalize operation 'mhlo.scatter'

Thanks for everyone’s effort so far in getting iree to utilise the M1 GPU via Vulkan. 🙂

Describe the bug

from iree import compiler

CODE = """
#loc0 = loc(unknown)
module @jit_prim_fun.12 {
  func.func public @main(%arg0: tensor<1x1xi32> loc(unknown), %arg1: tensor<1xi32> loc(unknown), %arg2: tensor<1xi32> loc(unknown)) -> tensor<1x1xi32> {
    %0 = "mhlo.scatter"(%arg0, %arg1, %arg2) ({
    ^bb0(%arg3: tensor<i32> loc(unknown), %arg4: tensor<i32> loc(unknown)):
      "mhlo.return"(%arg4) : (tensor<i32>) -> () loc(#loc1)
    }) {indices_are_sorted = true, scatter_dimension_numbers = #mhlo.scatter<update_window_dims = [0], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1]>, unique_indices = true} : (tensor<1x1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x1xi32> loc(#loc1)
    return %0 : tensor<1x1xi32> loc(#loc0)
  } loc(#loc0)
} loc(#loc0)
#loc1 = loc("jit(scatter)/jit(main)/scatter[update_consts=() dimension_numbers=ScatterDimensionNumbers(update_window_dims=(0,), inserted_window_dims=(1,), scatter_dims_to_operand_dims=(1,)) indices_are_sorted=True unique_indices=True mode=GatherScatterMode.FILL_OR_DROP]"("/Users/birch/anaconda3/envs/torch-nightly/lib/python3.9/site-packages/transformers/models/bart/modeling_flax_bart.py":926:1))
"""

extra_args = ["--mlir-print-ir-after-failure"]

# also reproduces with backend "dylib", the default backend for macOS
iree_binary = compiler.compile_str(
    CODE, target_backends=["vulkan"], input_type="mhlo", extra_args=extra_args)
Error invoking IREE compiler tool iree-compile
Diagnostics:
/Users/birch/anaconda3/envs/torch-nightly/lib/python3.9/site-packages/transformers/models/bart/modeling_flax_bart.py:926:1: error: failed to legalize operation 'mhlo.scatter' that was explicitly marked illegal
        input_ids = input_ids.at[(..., -1)].set(self.config.eos_token_id)
// -----// IR Dump After ConvertMHLOToLinalgExt Failed //----- //
func.func public @main(%arg0: tensor<1x1xi32> loc(unknown), %arg1: tensor<1xi32> loc(unknown), %arg2: tensor<1xi32> loc(unknown)) -> tensor<1x1xi32> {
  %0 = "mhlo.scatter"(%arg0, %arg1, %arg2) ({
  ^bb0(%arg3: tensor<i32> loc(unknown), %arg4: tensor<i32> loc(unknown)):
    "mhlo.return"(%arg4) : (tensor<i32>) -> () loc("jit(scatter)/jit(main)/scatter[update_consts=() dimension_numbers=ScatterDimensionNumbers(update_window_dims=(0,), inserted_window_dims=(1,), scatter_dims_to_operand_dims=(1,)) indices_are_sorted=True unique_indices=True mode=GatherScatterMode.FILL_OR_DROP]"("/Users/birch/anaconda3/envs/torch-nightly/lib/python3.9/site-packages/transformers/models/bart/modeling_flax_bart.py":926:1))
  }) {indices_are_sorted = true, scatter_dimension_numbers = #mhlo.scatter<update_window_dims = [0], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1]>, unique_indices = true} : (tensor<1x1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x1xi32> loc("jit(scatter)/jit(main)/scatter[update_consts=() dimension_numbers=ScatterDimensionNumbers(update_window_dims=(0,), inserted_window_dims=(1,), scatter_dims_to_operand_dims=(1,)) indices_are_sorted=True unique_indices=True mode=GatherScatterMode.FILL_OR_DROP]"("/Users/birch/anaconda3/envs/torch-nightly/lib/python3.9/site-packages/transformers/models/bart/modeling_flax_bart.py":926:1))
  return %0 : tensor<1x1xi32> loc(unknown)
} loc(unknown)
Invoked with:
 iree-compile /Users/birch/anaconda3/envs/torch-nightly/lib/python3.9/site-packages/iree/compiler/tools/../_mlir_libs/iree-compile - --iree-input-type=mhlo --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --iree-mlir-to-vm-bytecode-module --iree-llvm-embedded-linker-path=/Users/birch/anaconda3/envs/torch-nightly/lib/python3.9/site-packages/iree/compiler/tools/../_mlir_libs/iree-lld --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --mlir-print-ir-after-failure
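For context, the scatter being compiled comes from JAX’s functional update `input_ids.at[(..., -1)].set(self.config.eos_token_id)` shown in the diagnostics above: it writes a constant into the last position along the trailing axis and returns a new array. A plain-Python sketch of that update (the `eos_token_id` value below is illustrative, not taken from the model config):

```python
# Sketch of the array update that lowers to the failing mhlo.scatter:
# write a constant into the last element of each row, returning a new
# array and leaving the original untouched (like jnp's .at[...].set()).

def set_last_column(rows, value):
    """Return a copy of `rows` with the last element of each row replaced."""
    return [row[:-1] + [value] for row in rows]

input_ids = [[0]]          # shape (1, 1), matching tensor<1x1xi32>
eos_token_id = 2           # illustrative value
updated = set_last_column(input_ids, eos_token_id)
print(updated)             # [[2]]
print(input_ids)           # [[0]] -- original is untouched
```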

For a more realistic macOS/MoltenVK integration, I’d recommend adding the following extra_args too:
https://github.com/nod-ai/SHARK/blob/1186d7c58e6046aea6a6115c608dbd77728e7aca/shark/iree_utils.py#L93-L96

extra_args += [
    "--iree-llvm-target-triple=arm64-apple-darwin21.5.0",
    "--iree-flow-demote-i64-to-i32",
    "--iree-vulkan-target-triple=m1-moltenvk-macos",
    "--iree-llvm-target-cpu-features=host",
    "--iree-mhlo-demote-i64-to-i32=false",
]

The problem reproduces either way.

Reproducing in an integration context

Using very recent jax+jaxlib and iree builds, I invoked a known-good jax model, dalle-playground, with the environment variables:

JAX_PLATFORMS=iree
JAX_IREE_BACKEND=vulkan

and extra_args:

extra_args += [
    "--iree-llvm-target-triple=arm64-apple-darwin21.5.0",
    "--iree-flow-demote-i64-to-i32",
    "--iree-vulkan-target-triple=m1-moltenvk-macos",
    "--iree-llvm-target-cpu-features=host",
    "--iree-mhlo-demote-i64-to-i32=false",
]

Then follow these steps to install & run dalle-playground.
Typical output looks like this.
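Those environment variables have to be visible before jax initializes its backend; in my runs I exported them in the shell, but an in-process equivalent (a sketch, assuming you control the entry point) looks like:

```python
import os

# Select iree as the jax platform and vulkan as iree's target backend.
# Set these before jax is imported, since jax reads JAX_PLATFORMS during
# backend discovery (JAX_IREE_BACKEND is read by the experimental
# jax-on-iree bridge).
os.environ["JAX_PLATFORMS"] = "iree"
os.environ["JAX_IREE_BACKEND"] = "vulkan"

# import jax  # only import jax after the environment is configured
```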

Additional context

I reproduced this against:
jax+jaxlib commit https://github.com/google/jax/commit/345cc19949273cc414d94e6f13d0620b780af465 (which I built myself locally)
iree compiler+runtime release candidate-20220606.161 (https://github.com/google/iree/commit/f90a3a13807b312407d08a20965a6b5774e3371b)
dalle-playground commit https://github.com/saharmor/dalle-playground/commit/5a7be250e5c377b18a3f9f07631d177bcf91c934
MacBook Pro (14-inch, 2021) M1 Max (2 efficiency cores + 8 perf cores + 64GB unified memory + 32 GPU cores)

About this issue

  • State: closed
  • Created 2 years ago
  • Comments: 16 (11 by maintainers)

Most upvoted comments

Same works at the tag.

If, in src/iree/compiler/InputConversion/MHLO/Passes.cpp, I swap these two passes around:

passManager.addNestedPass<func::FuncOp>(createConvertMHLOToLinalgExtPass());
passManager.addNestedPass<func::FuncOp>(createMHLOToLinalgOnTensorsPass());

then it passes.

Seems due to

target.addIllegalOp<mhlo::SortOp, mhlo::ScatterOp, mhlo::FftOp,
                    mhlo::ReverseOp>();

in ConvertMHLOToLinalgExt.cpp, so any scatter that can’t be lowered to the LinalgExt version fails. Assigning to @hanhanW, who added this, for comment.
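For anyone unfamiliar with MLIR’s dialect-conversion machinery: once an op is registered via addIllegalOp, the conversion must rewrite every instance of it, or the whole pass fails with exactly the “failed to legalize” diagnostic above. A toy Python model of that rule (illustrative only — not the real MLIR API; the op names are placeholders):

```python
# Toy model of MLIR dialect-conversion legality: ops explicitly marked
# illegal must be rewritten by some pattern, otherwise the pass fails.

ILLEGAL_OPS = {"mhlo.sort", "mhlo.scatter", "mhlo.fft", "mhlo.reverse"}

def run_conversion(ops, patterns):
    """Apply rewrite patterns, then fail if any explicitly-illegal op remains."""
    rewritten = [patterns.get(op, op) for op in ops]
    for op in rewritten:
        if op in ILLEGAL_OPS:
            raise RuntimeError(
                f"failed to legalize operation '{op}' "
                "that was explicitly marked illegal")
    return rewritten

# A pattern that knows how to lower scatter succeeds:
patterns = {"mhlo.scatter": "iree_linalg_ext.scatter"}
print(run_conversion(["mhlo.scatter"], patterns))

# With no applicable pattern (e.g. an unsupported scatter shape),
# the pass fails with the diagnostic seen in this issue:
try:
    run_conversion(["mhlo.scatter"], {})
except RuntimeError as e:
    print(e)
```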