iree: mhlo.dynamic_broadcast_in_dim was explicitly marked illegal
This is in the context of JAX+IREE for dynamic shapes. The code below tries to broadcast f32[w] to f32[w, w], but it fails with "failed to legalize operation 'mhlo.dynamic_broadcast_in_dim' that was explicitly marked illegal".
Is the MHLO code that I am generating valid? Should I generate something else?
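To make the intended result concrete, here is a small pure-Python sketch of what the repro asks for. It assumes the usual XLA/MHLO broadcast_in_dim rule: operand dimension i maps to result dimension broadcast_dimensions[i], and that operand dimension must either be 1 or match the result dimension it maps to. The function names (reference_broadcast_in_dim, dynamic_broadcast_example) are hypothetical illustrations, not IREE or MHLO APIs. The sketch only shows the expected values and says nothing about why IREE fails to legalize the op.

```python
from typing import Any, Callable, List, Sequence, Tuple


def shape_of(x: Any) -> Tuple[int, ...]:
    """Shape of a rectangular nested list; a scalar has shape ()."""
    shape = []
    while isinstance(x, list):
        shape.append(len(x))
        x = x[0] if x else None
    return tuple(shape)


def get(x: Any, idx: Sequence[int]) -> Any:
    """Index a nested list with a multi-dimensional index."""
    for i in idx:
        x = x[i]
    return x


def build(shape: Tuple[int, ...], fn: Callable[[Tuple[int, ...]], Any],
          prefix: Tuple[int, ...] = ()) -> Any:
    """Build a nested list of the given shape, where element idx is fn(idx)."""
    if len(prefix) == len(shape):
        return fn(prefix)
    return [build(shape, fn, prefix + (i,)) for i in range(shape[len(prefix)])]


def reference_broadcast_in_dim(operand: Any, out_shape: Sequence[int],
                               broadcast_dimensions: Sequence[int]) -> Any:
    """Hypothetical reference model of broadcast_in_dim on nested lists."""
    in_shape = shape_of(operand)
    if len(in_shape) != len(broadcast_dimensions):
        raise ValueError("need one broadcast dimension per operand dimension")
    for i, d in enumerate(broadcast_dimensions):
        # Each operand dim must be 1 (broadcast) or equal the result dim it maps to.
        if in_shape[i] not in (1, out_shape[d]):
            raise ValueError(f"operand dim {i} ({in_shape[i]}) incompatible "
                             f"with result dim {d} ({out_shape[d]})")

    def element(out_idx: Tuple[int, ...]) -> Any:
        # Read operand dim i from result dim broadcast_dimensions[i], or 0 if size 1.
        src = [0 if in_shape[i] == 1 else out_idx[d]
               for i, d in enumerate(broadcast_dimensions)]
        return get(operand, src)

    return build(tuple(out_shape), element)


def dynamic_broadcast_example(arg0: List[float]) -> List[List[float]]:
    """Mirrors the repro: broadcast f32[w] to f32[w, w]."""
    w = len(arg0)        # %0: mhlo.get_dimension_size, dimension = 0
    out_shape = [w, w]   # %1, %2: scalar -> 1-element tensor; %3: concatenate
    # %4: broadcast_dimensions = dense<1> maps operand dim 0 to result dim 1.
    return reference_broadcast_in_dim(arg0, out_shape, broadcast_dimensions=[1])


print(dynamic_broadcast_example([1.0, 2.0, 3.0]))
```

With w = 3, every row of the result is a copy of arg0, i.e. result[i][j] == arg0[j], which is what broadcast_dimensions = dense<1> requests.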
Repro:
from iree.compiler import compile_str
CODE = """
module @jit_f.3 {
func public @main(%arg0: tensor<?xf32> loc(unknown)) -> tensor<?x?xf32> {
%0 = "mhlo.get_dimension_size"(%arg0) {dimension = 0 : i64} : (tensor<?xf32>) -> tensor<i32>
%1 = "mhlo.broadcast_in_dim"(%0) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<i32>) -> tensor<1xi32>
%2 = "mhlo.broadcast_in_dim"(%0) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<i32>) -> tensor<1xi32>
%3 = "mhlo.concatenate"(%1, %2) {dimension = 0 : i64} : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32>
%4 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %3) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<2xi32>) -> tensor<?x?xf32>
return %4 : tensor<?x?xf32>
}
}
"""
compiled_flatbuffer = compile_str(CODE, target_backends=["dylib"], input_type="mhlo")
And the error:
$ python tests/iree_repro.py
Traceback (most recent call last):
File "/Users/necula/Source/jax/tests/iree_repro.py", line 16, in <module>
compiled_flatbuffer = compile_str(CODE, target_backends=["dylib"], input_type="mhlo")
File "/Users/necula/.pyenv/versions/jax39/lib/python3.9/site-packages/iree/compiler/tools/core.py", line 262, in compile_str
result = invoke_immediate(cl, immediate_input=input_bytes)
File "/Users/necula/.pyenv/versions/jax39/lib/python3.9/site-packages/iree/compiler/tools/binaries.py", line 201, in invoke_immediate
raise CompilerToolError(process)
iree.compiler.tools.binaries.CompilerToolError: Error invoking IREE compiler tool ireec
Diagnostics:
<stdin>:8:10: error: failed to legalize operation 'mhlo.dynamic_broadcast_in_dim' that was explicitly marked illegal
%4 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %3) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<2xi32>) -> tensor<?x?xf32>
^
<stdin>:2:1: error: conversion from source -> vm failed
module @jit_f.3 {
^
Invoked with:
ireec /Users/necula/.pyenv/versions/jax39/lib/python3.9/site-packages/iree/compiler/tools/../_mlir_libs/ireec - --iree-input-type=mhlo --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=dylib --iree-mlir-to-vm-bytecode-module --iree-llvm-embedded-linker-path=/Users/necula/.pyenv/versions/jax39/lib/python3.9/site-packages/iree/compiler/tools/../_mlir_libs/iree-lld --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false
About this issue
- State: closed
- Created 3 years ago
- Comments: 15 (3 by maintainers)
Thank you. At this point I am going to sign off for a vacation, but I will pick this up when I return.