import iree.compiler
CODE = """
#loc0 = loc(unknown)
module @jit_prim_fun.0 {
func.func public @main(%arg0: tensor<0x2xi16> loc(unknown), %arg1: tensor<i16> loc(unknown)) -> tensor<3x3xi16> {
%0 = "mhlo.pad"(%arg0, %arg1) {edge_padding_high = dense<[2, 1]> : tensor<2xi64>, edge_padding_low = dense<[1, 0]> : tensor<2xi64>, interior_padding = dense<[1, 0]> : tensor<2xi64>} : (tensor<0x2xi16>, tensor<i16>) -> tensor<3x3xi16> loc(#loc1)
return %0 : tensor<3x3xi16> loc(#loc0)
} loc(#loc0)
} loc(#loc0)
#loc1 = loc("jit(pad)/jit(main)/pad[padding_config=((1, 2, 1), (0, 1, 0))]"("/Users/phawkins/p/jax/tests/lax_test.py":1419:1))
"""
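# Compile the module for the CPU ("dylib") backend. This fails with
# CompilerToolError; note that the pad operand, tensor<0x2xi16>, has zero elements.
iree_binary = iree.compiler.compile_str(
    CODE, target_backends=["dylib"], input_type="mhlo")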
Traceback (most recent call last):
  File "/Users/phawkins/p/jax/t.py", line 14, in <module>
    iree_binary = iree.compiler.compile_str(
  File "/Users/phawkins/.pyenv/versions/py310/lib/python3.10/site-packages/iree/compiler/tools/core.py", line 293, in compile_str
    result = invoke_immediate(cl, immediate_input=input_bytes)
  File "/Users/phawkins/.pyenv/versions/py310/lib/python3.10/site-packages/iree/compiler/tools/binaries.py", line 200, in invoke_immediate
    raise CompilerToolError(process)
iree.compiler.tools.binaries.CompilerToolError: Error invoking IREE compiler tool iree-compile
Diagnostics:
/Users/phawkins/p/jax/tests/lax_test.py:1419:1: error: expected result type to be 'memref<0x2xi16, affine_map<(d0, d1) -> (0)>>' or a rank-reduced version. (mismatch of result layout)
fun = lambda operand: lax.pad(operand, np.array(0, dtype), pads)
^
/Users/phawkins/p/jax/tests/lax_test.py:1419:1: error: failed to run translation of source executable to target executable for backend #hal.executable.target<"llvm", "embedded-elf-x86_64", {cpu_features = "", data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128", native_vector_size = 16 : index, target_triple = "x86_64-unknown-unknown-eabi-elf"}>
fun = lambda operand: lax.pad(operand, np.array(0, dtype), pads)
^
/Users/phawkins/p/jax/tests/lax_test.py:1419:1: error: failed to serialize executables
fun = lambda operand: lax.pad(operand, np.array(0, dtype), pads)
^
compilation failed
Invoked with:
iree-compile /Users/phawkins/.pyenv/versions/py310/lib/python3.10/site-packages/iree/compiler/tools/../_mlir_libs/iree-compile - --iree-input-type=mhlo --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=dylib --iree-mlir-to-vm-bytecode-module --iree-llvm-embedded-linker-path=/Users/phawkins/.pyenv/versions/py310/lib/python3.10/site-packages/iree/compiler/tools/../_mlir_libs/iree-lld --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false
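The failing `mhlo.pad` never reads an operand element: with a `0x2` operand, the `3x3` result consists entirely of the padding value. Below is a minimal sketch of that shape arithmetic in plain Python. It is not an IREE or MHLO API; `pad_dim_size`, `pad_shape`, and `fold_zero_sized_pad` are hypothetical helpers, and the sketch assumes non-negative padding amounts (XLA also allows negative edge padding, which this does not handle).

```python
from math import prod


def pad_dim_size(n, low, high, interior):
    """Size of one dimension after padding (assumes non-negative amounts)."""
    # Interior padding goes only *between* adjacent elements, so a dimension
    # of size n gets n - 1 gaps, and an empty dimension gets none.
    return low + n + max(n - 1, 0) * interior + high


def pad_shape(shape, config):
    """Padded shape; config holds one (low, high, interior) triple per dim."""
    return tuple(
        pad_dim_size(n, low, high, interior)
        for n, (low, high, interior) in zip(shape, config)
    )


def fold_zero_sized_pad(shape, config, pad_value):
    """Evaluate a pad whose operand has zero elements, as nested lists.

    With no operand elements to copy, every output element is pad_value,
    so the result is a splat of pad_value with the padded shape.
    """
    if prod(shape) != 0:
        raise ValueError("operand must have zero elements")

    def splat(dims):
        if not dims:
            return pad_value
        return [splat(dims[1:]) for _ in range(dims[0])]

    return splat(pad_shape(shape, config))


# Same configuration as the reproducer: padding_config=((1, 2, 1), (0, 1, 0)).
config = [(1, 2, 1), (0, 1, 0)]
print(pad_shape((0, 2), config))               # (3, 3)
print(fold_zero_sized_pad((0, 2), config, 0))  # [[0, 0, 0], [0, 0, 0], [0, 0, 0]]
```

Folding the op into a constant this way means the zero-sized operand never has to be materialized at all.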
First off: I initially misread your comment as implying that there was an HLO spec, and got excited. But then I read it again…
So we just need to do this. I suspect you are both right: a few simple high-level patterns that replace zero-sized ops with constants will turn most of the affected program into dead code. After that, only a few things are left…
Zero-sized shapes are legal according to the only HLO spec we have, namely XLA itself: XLA accepts them. I admit that is partly because I made XLA accept them years ago. We should make this explicit in the MHLO spec, whichever way we decide.
My personal belief is that zero-sized shapes are very useful as a base case; it is considerably more inconvenient to work around their absence at the Python level than it is to add a compiler pass that removes them. The latter is what XLA, being a statically shaped compiler, does: an early HLO pass replaces every zero-sized operation with a constant, after which zero-element shapes need to be handled in only a few places.
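A toy version of this kind of cleanup, as a minimal sketch over a hypothetical SSA-style IR: the `Op` class, the op kinds, and both pass functions are invented for illustration and are not XLA's or MHLO's code. The first pass replaces every non-parameter op whose result has zero elements with an operand-free constant; the second removes ops that no longer feed the output, which is where most of the zero-sized computation disappears.

```python
from dataclasses import dataclass, field
from math import prod


@dataclass
class Op:
    kind: str     # hypothetical op kinds: "param", "constant", "slice", ...
    shape: tuple  # static result shape
    operands: list = field(default_factory=list)  # indices of earlier ops


def replace_zero_sized_with_constants(ops):
    """Replace each non-parameter op with a zero-element result by a constant."""
    rewritten = []
    for op in ops:
        if op.kind not in ("param", "constant") and prod(op.shape) == 0:
            # A zero-element value carries no data: its shape alone determines
            # it. Because the replacement has no operands, whatever previously
            # computed its inputs may now be dead.
            rewritten.append(Op("constant", op.shape))
        else:
            rewritten.append(op)
    return rewritten


def dead_code_elimination(ops, root):
    """Keep only ops reachable from ops[root]; return (new_ops, new_root)."""
    live, stack = set(), [root]
    while stack:
        i = stack.pop()
        if i not in live:
            live.add(i)
            stack.extend(ops[i].operands)
    # Renumber the surviving ops, preserving their original order.
    order = sorted(live)
    new_index = {old: new for new, old in enumerate(order)}
    new_ops = [
        Op(ops[i].kind, ops[i].shape, [new_index[j] for j in ops[i].operands])
        for i in order
    ]
    return new_ops, new_index[root]


program = [
    Op("param", (3, 4)),           # 0: x
    Op("slice", (0, 4), [0]),      # 1: y = x[0:0, :], zero rows
    Op("add", (0, 4), [1, 1]),     # 2: z = y + y
    Op("concat", (3, 4), [2, 0]),  # 3: out = concat(z, x) along dim 0
]
optimized, root = dead_code_elimination(
    replace_zero_sized_with_constants(program), root=3)
for op in optimized:
    print(op.kind, op.shape, op.operands)
# param (3, 4) []
# constant (0, 4) []
# concat (3, 4) [1, 0]
```

Only ops with zero-element *results* are rewritten here; an op like the reproducer's `mhlo.pad`, which has a zero-sized operand but a non-empty result, would need a separate pattern that folds it to a splat of the padding value.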
I have less of an opinion about dynamically zero-sized shapes: I can believe they may be significantly harder to deal with.