iree: Failed to legalize operation 'mhlo.scatter'
Thanks for everyone’s effort so far in getting IREE to utilise the M1 GPU via Vulkan. 🙂
Describe the bug
```python
from iree import compiler

CODE = """
#loc0 = loc(unknown)
module @jit_prim_fun.12 {
  func.func public @main(%arg0: tensor<1x1xi32> loc(unknown), %arg1: tensor<1xi32> loc(unknown), %arg2: tensor<1xi32> loc(unknown)) -> tensor<1x1xi32> {
    %0 = "mhlo.scatter"(%arg0, %arg1, %arg2) ({
    ^bb0(%arg3: tensor<i32> loc(unknown), %arg4: tensor<i32> loc(unknown)):
      "mhlo.return"(%arg4) : (tensor<i32>) -> () loc(#loc1)
    }) {indices_are_sorted = true, scatter_dimension_numbers = #mhlo.scatter<update_window_dims = [0], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1]>, unique_indices = true} : (tensor<1x1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x1xi32> loc(#loc1)
    return %0 : tensor<1x1xi32> loc(#loc0)
  } loc(#loc0)
} loc(#loc0)
#loc1 = loc("jit(scatter)/jit(main)/scatter[update_consts=() dimension_numbers=ScatterDimensionNumbers(update_window_dims=(0,), inserted_window_dims=(1,), scatter_dims_to_operand_dims=(1,)) indices_are_sorted=True unique_indices=True mode=GatherScatterMode.FILL_OR_DROP]"("/Users/birch/anaconda3/envs/torch-nightly/lib/python3.9/site-packages/transformers/models/bart/modeling_flax_bart.py":926:1))
"""

extra_args = ["--mlir-print-ir-after-failure"]

# also reproduces with backend "dylib", the default backend for macOS
iree_binary = compiler.compile_str(
    CODE, target_backends=["vulkan"], input_type="mhlo", extra_args=extra_args)
```
```
Error invoking IREE compiler tool iree-compile
Diagnostics:
/Users/birch/anaconda3/envs/torch-nightly/lib/python3.9/site-packages/transformers/models/bart/modeling_flax_bart.py:926:1: error: failed to legalize operation 'mhlo.scatter' that was explicitly marked illegal
input_ids = input_ids.at[(..., -1)].set(self.config.eos_token_id)
// -----// IR Dump After ConvertMHLOToLinalgExt Failed //----- //
func.func public @main(%arg0: tensor<1x1xi32> loc(unknown), %arg1: tensor<1xi32> loc(unknown), %arg2: tensor<1xi32> loc(unknown)) -> tensor<1x1xi32> {
  %0 = "mhlo.scatter"(%arg0, %arg1, %arg2) ({
  ^bb0(%arg3: tensor<i32> loc(unknown), %arg4: tensor<i32> loc(unknown)):
    "mhlo.return"(%arg4) : (tensor<i32>) -> () loc("jit(scatter)/jit(main)/scatter[update_consts=() dimension_numbers=ScatterDimensionNumbers(update_window_dims=(0,), inserted_window_dims=(1,), scatter_dims_to_operand_dims=(1,)) indices_are_sorted=True unique_indices=True mode=GatherScatterMode.FILL_OR_DROP]"("/Users/birch/anaconda3/envs/torch-nightly/lib/python3.9/site-packages/transformers/models/bart/modeling_flax_bart.py":926:1))
  }) {indices_are_sorted = true, scatter_dimension_numbers = #mhlo.scatter<update_window_dims = [0], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1]>, unique_indices = true} : (tensor<1x1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x1xi32> loc("jit(scatter)/jit(main)/scatter[update_consts=() dimension_numbers=ScatterDimensionNumbers(update_window_dims=(0,), inserted_window_dims=(1,), scatter_dims_to_operand_dims=(1,)) indices_are_sorted=True unique_indices=True mode=GatherScatterMode.FILL_OR_DROP]"("/Users/birch/anaconda3/envs/torch-nightly/lib/python3.9/site-packages/transformers/models/bart/modeling_flax_bart.py":926:1))
  return %0 : tensor<1x1xi32> loc(unknown)
} loc(unknown)
Invoked with:
iree-compile /Users/birch/anaconda3/envs/torch-nightly/lib/python3.9/site-packages/iree/compiler/tools/../_mlir_libs/iree-compile - --iree-input-type=mhlo --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --iree-mlir-to-vm-bytecode-module --iree-llvm-embedded-linker-path=/Users/birch/anaconda3/envs/torch-nightly/lib/python3.9/site-packages/iree/compiler/tools/../_mlir_libs/iree-lld --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --mlir-print-ir-after-failure
```
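For context, the `mhlo.scatter` here comes from JAX's indexed update on the last token position (the `modeling_flax_bart.py` line shown in the diagnostic). Its semantics are just a functional "set the last column"; a plain-Python sketch (my illustration, not IREE or JAX code) makes that concrete:

```python
def set_last_column(input_ids, eos_token_id):
    """Functional equivalent of JAX's input_ids.at[..., -1].set(eos_token_id):
    return a copy of a 2-D list with the last entry of each row replaced."""
    return [row[:-1] + [eos_token_id] for row in input_ids]

# A 1x1 input, matching the tensor<1x1xi32> operand in the repro IR:
print(set_last_column([[0]], 2))  # -> [[2]]
```

This is the whole computation the failing scatter represents, which is why it is surprising that the lowering rejects it.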
To do a more realistic macOS/MoltenVK integration, I'd recommend the following extra_args too (taken from SHARK's iree_utils.py):

https://github.com/nod-ai/SHARK/blob/1186d7c58e6046aea6a6115c608dbd77728e7aca/shark/iree_utils.py#L93-L96

```python
extra_args += [
    "--iree-llvm-target-triple=arm64-apple-darwin21.5.0",
    "--iree-flow-demote-i64-to-i32",
    "--iree-vulkan-target-triple=m1-moltenvk-macos",
    "--iree-llvm-target-cpu-features=host",
    "--iree-mhlo-demote-i64-to-i32=false"]
```
The problem reproduces either way.
Reproducing in an integration context
Using very new jax+jaxlib and iree builds, I invoked a known-good JAX model, dalle-playground, with the environment variables:

```
JAX_PLATFORMS=iree
JAX_IREE_BACKEND=vulkan
```

and the extra_args:

```python
extra_args += [
    "--iree-llvm-target-triple=arm64-apple-darwin21.5.0",
    "--iree-flow-demote-i64-to-i32",
    "--iree-vulkan-target-triple=m1-moltenvk-macos",
    "--iree-llvm-target-cpu-features=host",
    "--iree-mhlo-demote-i64-to-i32=false"]
```
Then follow these steps to install & run dalle-playground.
Typical output looks like this.
Additional context
I reproduced this against:

- jax+jaxlib commit https://github.com/google/jax/commit/345cc19949273cc414d94e6f13d0620b780af465 (which I built myself locally)
- iree compiler+runtime release candidate-20220606.161 (https://github.com/google/iree/commit/f90a3a13807b312407d08a20965a6b5774e3371b)
- dalle-playground commit https://github.com/saharmor/dalle-playground/commit/5a7be250e5c377b18a3f9f07631d177bcf91c934
- MacBook Pro (14-inch, 2021) M1 Max (2 efficiency cores + 8 performance cores, 64 GB unified memory, 32 GPU cores)
About this issue
- Original URL
- State: closed
- Created 2 years ago
- Comments: 16 (11 by maintainers)
The same works at the tag.

If, in src/iree/compiler/InputConversion/MHLO/Passes.cpp, I switch these two pass registrations around, then it passes:

```cpp
passManager.addNestedPass<func::FuncOp>(createConvertMHLOToLinalgExtPass());
passManager.addNestedPass<func::FuncOp>(createMHLOToLinalgOnTensorsPass());
```
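In other words, the workaround is to let the general Linalg-on-tensors lowering run before the LinalgExt conversion. A sketch of Passes.cpp after the swap (a diagnostic aid, not a proposed fix):

```cpp
// Workaround order: the general MHLO->Linalg lowering gets first chance at the
// scatter; only then does the LinalgExt conversion run, which marks
// mhlo.scatter illegal and fails on patterns it cannot handle.
passManager.addNestedPass<func::FuncOp>(createMHLOToLinalgOnTensorsPass());
passManager.addNestedPass<func::FuncOp>(createConvertMHLOToLinalgExtPass());
```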
This seems due to the conversion in ConvertMHLOToLinalgExt.cpp marking mhlo.scatter illegal, so any scatter that can't be lowered to the LinalgExt version fails. Assigning to @hanhanW, who added this, for comment.