iree: Lowering failure for mhlo.pad of empty tensor

import iree.compiler

CODE = """
#loc0 = loc(unknown)
module @jit_prim_fun.0 {
  func.func public @main(%arg0: tensor<0x2xi16> loc(unknown), %arg1: tensor<i16> loc(unknown)) -> tensor<3x3xi16> {
    %0 = "mhlo.pad"(%arg0, %arg1) {edge_padding_high = dense<[2, 1]> : tensor<2xi64>, edge_padding_low = dense<[1, 0]> : tensor<2xi64>, interior_padding = dense<[1, 0]> : tensor<2xi64>} : (tensor<0x2xi16>, tensor<i16>) -> tensor<3x3xi16> loc(#loc1)
    return %0 : tensor<3x3xi16> loc(#loc0)
  } loc(#loc0)
} loc(#loc0)
#loc1 = loc("jit(pad)/jit(main)/pad[padding_config=((1, 2, 1), (0, 1, 0))]"("/Users/phawkins/p/jax/tests/lax_test.py":1419:1))
"""

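# Compile the MHLO module above for the CPU ("dylib") backend; this call raises
# CompilerToolError with the diagnostics below.
iree_binary = iree.compiler.compile_str(
    CODE, target_backends=["dylib"], input_type="mhlo")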

Traceback (most recent call last):
  File "/Users/phawkins/p/jax/t.py", line 14, in <module>
    iree_binary = iree.compiler.compile_str(
  File "/Users/phawkins/.pyenv/versions/py310/lib/python3.10/site-packages/iree/compiler/tools/core.py", line 293, in compile_str
    result = invoke_immediate(cl, immediate_input=input_bytes)
  File "/Users/phawkins/.pyenv/versions/py310/lib/python3.10/site-packages/iree/compiler/tools/binaries.py", line 200, in invoke_immediate
    raise CompilerToolError(process)
iree.compiler.tools.binaries.CompilerToolError: Error invoking IREE compiler tool iree-compile
Diagnostics:
/Users/phawkins/p/jax/tests/lax_test.py:1419:1: error: expected result type to be 'memref<0x2xi16, affine_map<(d0, d1) -> (0)>>' or a rank-reduced version. (mismatch of result layout)
    fun = lambda operand: lax.pad(operand, np.array(0, dtype), pads)
^
/Users/phawkins/p/jax/tests/lax_test.py:1419:1: error: failed to run translation of source executable to target executable for backend #hal.executable.target<"llvm", "embedded-elf-x86_64", {cpu_features = "", data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128", native_vector_size = 16 : index, target_triple = "x86_64-unknown-unknown-eabi-elf"}>
    fun = lambda operand: lax.pad(operand, np.array(0, dtype), pads)
^
/Users/phawkins/p/jax/tests/lax_test.py:1419:1: error: failed to serialize executables
    fun = lambda operand: lax.pad(operand, np.array(0, dtype), pads)
^
compilation failed


Invoked with:
 iree-compile /Users/phawkins/.pyenv/versions/py310/lib/python3.10/site-packages/iree/compiler/tools/../_mlir_libs/iree-compile - --iree-input-type=mhlo --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=dylib --iree-mlir-to-vm-bytecode-module --iree-llvm-embedded-linker-path=/Users/phawkins/.pyenv/versions/py310/lib/python3.10/site-packages/iree/compiler/tools/../_mlir_libs/iree-lld --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false
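
For context, here is what the reproducer should return. The sketch below is a minimal NumPy model of `mhlo.pad` semantics. It assumes non-negative padding only (XLA also permits negative edge padding, which this sketch does not handle), and `reference_pad` is a hypothetical helper, not an IREE or JAX API. Each output dimension has size `low + high + n + max(n - 1, 0) * interior`, so the `0x2` operand becomes `(1 + 2 + 0) x (0 + 1 + 2) = 3x3`. Because the operand has no elements, every output element is the padding value.

```python
import numpy as np


def reference_pad(operand, pad_value, low, high, interior):
    """Hypothetical NumPy model of mhlo.pad (non-negative padding only)."""
    operand = np.asarray(operand)
    # Each output dim is low + high + n, plus interior padding between the n elements.
    out_shape = tuple(
        lo + hi + n + max(n - 1, 0) * it
        for n, lo, hi, it in zip(operand.shape, low, high, interior)
    )
    # Start from an array made entirely of padding.
    result = np.full(out_shape, pad_value, dtype=operand.dtype)
    # Operand element i along a dim lands at lo + i * (it + 1); an empty dim selects nothing.
    index = tuple(
        slice(lo, lo + n + max(n - 1, 0) * it, it + 1)
        for n, lo, it in zip(operand.shape, low, interior)
    )
    result[index] = operand
    return result


x = np.zeros((0, 2), dtype=np.int16)  # the empty operand from the reproducer
out = reference_pad(x, 7, low=(1, 0), high=(2, 1), interior=(1, 0))
print(out.shape)     # (3, 3)
print(out.tolist())  # [[7, 7, 7], [7, 7, 7], [7, 7, 7]]
```

So the compiled module should simply produce a 3x3 `i16` tensor filled with `%arg1` (the JAX test uses 0; 7 is used here only to make the padding visible).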

About this issue

  • State: closed
  • Created 2 years ago
  • Comments: 27 (15 by maintainers)

Most upvoted comments

First thing: I initially misread your comment as implying that there was an HLO spec, and got excited. But then I read it again…

So we just need to do this. I suspect that both of you are right: a few simple high-level patterns that replace them with constants will render most of the program dead. Then there are only a few things left…

Zero-sized shapes are legal according to the only HLO spec we have, namely XLA itself: XLA accepts them. I admit that is partially because I made XLA accept them years ago. We should make this explicit in the MHLO spec (either way).

My personal belief is that zero-sized shapes are very useful as a base case; it is considerably more inconvenient to work around their absence at the Python level than it is to add a compiler pass that removes them. The latter is what XLA, being a statically shaped compiler, does: it has an early HLO pass that replaces all zero-sized operators with constants, and then the need to handle zero-element shapes is limited to only a few places.
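
A toy illustration of the kind of pass described above, under loudly stated assumptions: the `Op` class and `eliminate_zero_sized` function are hypothetical, work on a made-up single-result IR given in topological order, and are not XLA's or IREE's actual code. The sketch applies two rewrites. Any non-parameter op whose result has zero elements becomes an empty constant. A `pad` whose operand is empty (the case in this issue, and an example of a place that still needs explicit handling) becomes a broadcast of the padding value, since its result consists entirely of padding. A final reachability sweep then drops ops that became dead.

```python
import math
from dataclasses import dataclass, field


@dataclass(eq=False)  # compare ops by identity, so they can be used in sets and dicts
class Op:
    """Hypothetical toy IR node with a single, statically shaped result."""
    name: str  # e.g. "parameter", "constant", "pad", "add", "broadcast"
    shape: tuple
    operands: list = field(default_factory=list)


def num_elements(shape):
    return math.prod(shape)  # math.prod(()) == 1, so scalars are non-empty


def eliminate_zero_sized(ops, outputs):
    """Sketch of a zero-size cleanup pass; `ops` must be in topological order.

    Note: rewires the operands of the given ops in place.
    """
    replacement = {}  # original op -> op that replaces it
    new_ops = []
    for op in ops:
        op.operands = [replacement.get(o, o) for o in op.operands]
        new = op
        if op.name not in ("parameter", "constant") and num_elements(op.shape) == 0:
            # A zero-element result carries no data: an empty constant is equivalent.
            new = Op("constant", op.shape)
        elif op.name == "pad" and num_elements(op.operands[0].shape) == 0:
            # Empty operand: the whole result is padding, i.e. a broadcast of the pad value.
            new = Op("broadcast", op.shape, [op.operands[1]])
        if new is not op:
            replacement[op] = new
        new_ops.append(new)
    outputs = [replacement.get(o, o) for o in outputs]
    # Dead-code elimination: keep parameters plus everything reachable from outputs.
    live, stack = set(), list(outputs)
    while stack:
        op = stack.pop()
        if op not in live:
            live.add(op)
            stack.extend(op.operands)
    return [op for op in new_ops if op in live or op.name == "parameter"], outputs


x = Op("parameter", (0, 2))  # tensor<0x2xi16>
v = Op("parameter", ())      # tensor<i16> padding value
p = Op("pad", (3, 3), [x, v])
ops, outs = eliminate_zero_sized([x, v, p], [p])
print([op.name for op in ops])  # ['parameter', 'parameter', 'broadcast']
```

Chains of zero-sized ops collapse the same way: each link becomes a constant, and every constant except the one feeding the output is then dead and removed.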

I have less of an opinion about dynamically zero-sized shapes: I can believe they may be significantly harder to deal with.