iree: mhlo.dynamic_broadcast_in_dim was explicitly marked illegal

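This is in the context of JAX+IREE for dynamic shapes. The code below tries to broadcast f32[w] to f32[w, w]. Here broadcast_dimensions = [1] maps the input to the second output dimension, so every row of the result is a copy of the input. Compilation fails with: failed to legalize operation 'mhlo.dynamic_broadcast_in_dim' that was explicitly marked illegal.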

Is the MHLO code that I am generating valid? Should I generate something else?

Repro:

from iree.compiler import compile_str

CODE = """
module @jit_f.3  {
  func public @main(%arg0: tensor<?xf32> loc(unknown)) -> tensor<?x?xf32> {
    %0 = "mhlo.get_dimension_size"(%arg0) {dimension = 0 : i64} : (tensor<?xf32>) -> tensor<i32>
    %1 = "mhlo.broadcast_in_dim"(%0) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<i32>) -> tensor<1xi32>
    %2 = "mhlo.broadcast_in_dim"(%0) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<i32>) -> tensor<1xi32>
    %3 = "mhlo.concatenate"(%1, %2) {dimension = 0 : i64} : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32>
    %4 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %3) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<2xi32>) -> tensor<?x?xf32>
    return %4 : tensor<?x?xf32>
  }
}
"""

compiled_flatbuffer = compile_str(CODE, target_backends=["dylib"], input_type="mhlo")
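To check what the module is meant to compute, here is a small pure-Python reference sketch of broadcast_in_dim semantics. It assumes the standard MHLO rule: broadcast_dimensions[i] names the output dimension that operand dimension i maps to, and each operand dimension must have size 1 or match that output dimension. The helper dynamic_broadcast_in_dim_ref and its small utilities are hypothetical illustrations, not part of IREE or MHLO. The sketch replays %0 through %4 for a concrete w = 3.

```python
def shape_of(t):
    """Shape of a rectangular nested list; a scalar has shape ()."""
    shape = []
    while isinstance(t, list):
        shape.append(len(t))
        t = t[0] if t else None
    return tuple(shape)


def get(t, idx):
    """Read the element of nested list t at index tuple idx."""
    for i in idx:
        t = t[i]
    return t


def build(shape, fn, prefix=()):
    """Build a nested list of the given shape, filling each element with fn(index)."""
    if len(prefix) == len(shape):
        return fn(prefix)
    return [build(shape, fn, prefix + (i,)) for i in range(shape[len(prefix)])]


def dynamic_broadcast_in_dim_ref(operand, output_dimensions, broadcast_dimensions):
    """Hypothetical reference model of mhlo.dynamic_broadcast_in_dim.

    output_dimensions plays the role of the runtime shape operand (%3 in the repro);
    broadcast_dimensions[i] is the output dimension that operand dimension i maps to.
    """
    in_shape = shape_of(operand)
    out_shape = tuple(int(d) for d in output_dimensions)
    if len(broadcast_dimensions) != len(in_shape):
        raise ValueError("need one broadcast dimension per operand dimension")
    for in_dim, out_dim in enumerate(broadcast_dimensions):
        if not 0 <= out_dim < len(out_shape):
            raise ValueError(f"broadcast dimension {out_dim} is out of range")
        if in_shape[in_dim] not in (1, out_shape[out_dim]):
            raise ValueError(
                f"operand dim {in_dim} (size {in_shape[in_dim]}) does not match "
                f"output dim {out_dim} (size {out_shape[out_dim]})")

    def element(out_idx):
        # Size-1 operand dims are expanded (always read index 0); others follow the mapping.
        in_idx = tuple(0 if in_shape[d] == 1 else out_idx[broadcast_dimensions[d]]
                       for d in range(len(in_shape)))
        return get(operand, in_idx)

    return build(out_shape, element)


# Replay the repro for a concrete w = 3.
arg0 = [1.0, 2.0, 3.0]
w = len(arg0)      # %0: mhlo.get_dimension_size, dimension = 0
shape = [w] + [w]  # %1, %2 broadcast the scalar w to tensor<1xi32>; %3 concatenates them
result = dynamic_broadcast_in_dim_ref(arg0, shape, broadcast_dimensions=[1])  # %4
print(result)  # every row is a copy of arg0
```

Under these assumptions the op in the repro is well-formed: the single operand dimension maps to output dimension 1, and its size w matches the second entry of %3.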

And the resulting error:

$ python tests/iree_repro.py
Traceback (most recent call last):
  File "/Users/necula/Source/jax/tests/iree_repro.py", line 16, in <module>
    compiled_flatbuffer = compile_str(CODE, target_backends=["dylib"], input_type="mhlo")
  File "/Users/necula/.pyenv/versions/jax39/lib/python3.9/site-packages/iree/compiler/tools/core.py", line 262, in compile_str
    result = invoke_immediate(cl, immediate_input=input_bytes)
  File "/Users/necula/.pyenv/versions/jax39/lib/python3.9/site-packages/iree/compiler/tools/binaries.py", line 201, in invoke_immediate
    raise CompilerToolError(process)
iree.compiler.tools.binaries.CompilerToolError: Error invoking IREE compiler tool ireec
Diagnostics:
<stdin>:8:10: error: failed to legalize operation 'mhlo.dynamic_broadcast_in_dim' that was explicitly marked illegal
    %4 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %3) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<2xi32>) -> tensor<?x?xf32>
         ^
<stdin>:2:1: error: conversion from source -> vm failed
module @jit_f.3  {
^


Invoked with:
 ireec /Users/necula/.pyenv/versions/jax39/lib/python3.9/site-packages/iree/compiler/tools/../_mlir_libs/ireec - --iree-input-type=mhlo --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=dylib --iree-mlir-to-vm-bytecode-module --iree-llvm-embedded-linker-path=/Users/necula/.pyenv/versions/jax39/lib/python3.9/site-packages/iree/compiler/tools/../_mlir_libs/iree-lld --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false
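The first diagnostic comes from MLIR's dialect conversion framework. The conversion target declares mhlo.dynamic_broadcast_in_dim illegal, so every instance must be rewritten by some pattern, and the pass reports this error when none of them handles an instance. The toy Python model below illustrates only that mechanism. All names in it (apply_partial_conversion, lower_broadcast, lowered.broadcast, the static_shape flag) are hypothetical, not MLIR or IREE APIs. The model is single-pass, unlike real MLIR, and says nothing about why IREE's actual lowering did not match this op.

```python
class LegalizationError(Exception):
    pass


def apply_partial_conversion(ops, illegal_ops, patterns):
    """Toy, single-pass model of MLIR dialect conversion (not MLIR's real API).

    ops:         list of (op_name, attrs) tuples, in program order.
    illegal_ops: op names the conversion target has explicitly marked illegal.
    patterns:    op_name -> rewrite function; a rewrite returns a list of
                 replacement ops, or None if it cannot handle this instance.
    """
    converted = []
    for name, attrs in ops:
        if name not in illegal_ops:
            converted.append((name, attrs))  # ops not marked illegal pass through
            continue
        rewrite = patterns.get(name)
        replacement = rewrite(attrs) if rewrite is not None else None
        if replacement is None:
            # No pattern succeeded, so the illegal op would survive: conversion fails.
            raise LegalizationError(
                f"failed to legalize operation '{name}' that was explicitly marked illegal")
        converted.extend(replacement)
    return converted


# Hypothetical pattern that only handles one variant of the op.
# The condition is illustrative only; it is NOT the rule IREE actually applied.
def lower_broadcast(attrs):
    if attrs.get("static_shape"):
        return [("lowered.broadcast", attrs)]
    return None


patterns = {"mhlo.dynamic_broadcast_in_dim": lower_broadcast}
illegal = {"mhlo.dynamic_broadcast_in_dim"}
program = [("mhlo.get_dimension_size", {}),
           ("mhlo.dynamic_broadcast_in_dim", {"static_shape": False})]

try:
    apply_partial_conversion(program, illegal, patterns)
except LegalizationError as err:
    print(err)  # failed to legalize operation 'mhlo.dynamic_broadcast_in_dim' that was explicitly marked illegal
```

The point of the model is that "explicitly marked illegal" is a property of the conversion target, not of the input IR: the MHLO can be valid and still fail to compile if no lowering pattern accepts that particular form of the op.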

About this issue

  • State: closed
  • Created 3 years ago
  • Comments: 15 (3 by maintainers)

Most upvoted comments

Thank you. At this point I am going to sign off for a vacation, but I will pick this up when I return.