iree: IREE model gives wrong result
What happened?
I have a jax.numpy function that computes the NaN-ignoring mean (jnp.nanmean) of each column of a 2D tensor, but the IREE-compiled version does not give the right results.
building this func
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
running this func on Traced<ShapedArray(float32[2,8])>with<DynamicJaxprTrace(level=0/1)>
running this func on Traced<ShapedArray(float32[2,8])>with<DynamicJaxprTrace(level=1/0)>
[[ 5. 6. 7. 8. 9. 10. 11. 12.]]
[[ 5. inf inf inf 9. inf inf inf]]
Traceback (most recent call last):
  File "/home/ubuntu/shoreline/python/ceva.py", line 75, in <module>
    assert jnp.allclose(jax_result, iree_result)
AssertionError
The output above comes from the script below.
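For reference, the expected values can be reproduced with plain NumPy. With no NaNs in the input, nanmean over axis 0 is just the column mean, which can also be written as the masked column sum divided by the number of non-NaN entries in each column. The sketch below spells out that decomposition as an illustration only; it is not JAX's or IREE's actual lowering. The inf entries in the IREE output are what a nonzero float sum divided by a zero count would produce, but that is only a hypothesis: the report does not identify the cause.

```python
import numpy as np

# Same input as the reproducer script.
ii = np.array([[1, 2, 3, 4, 5, 6, 7, 8],
               [9, 10, 11, 12, 13, 14, 15, 16]], dtype=np.float32)

# Reference: NaN-ignoring mean down each column, kept as a (1, 8) array.
expected = np.nanmean(ii, axis=0, keepdims=True)

# The same quantity written out: masked column sum / non-NaN count per column.
mask = ~np.isnan(ii)
col_sum = np.where(mask, ii, 0.0).sum(axis=0, keepdims=True)
col_count = mask.sum(axis=0, keepdims=True)
manual = col_sum / col_count

print(expected)  # column means 5 through 12, matching the JAX result

# Dividing a nonzero float32 sum by a zero count yields inf under IEEE 754.
with np.errstate(divide="ignore"):
    bad = np.float32(10.0) / np.float32(0.0)
print(bad)  # inf
```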
Steps to reproduce your issue
Run the script below. The assertion should pass, but it does not.
import jax
import iree.compiler
import iree.compiler.tools.xla
import iree.runtime as ireert
import numpy as np
import jax.numpy as jnp


def build_group_avg():
    print("building this func")

    def _f(x):
        return jnp.nanmean(x, axis=0, keepdims=True)

    def __inner__(data):
        print(f"running this func on {data}")
        # vmap will apply the mean to each column
        f = jax.vmap(_f, in_axes=1, out_axes=1)
        return f(data)

    return __inner__

def aot(function, static_argnums, *args, **options):
    """Traces and compiles a function, flattening the input args.

    This is intended to be a lower-level interface for compiling a JAX function to
    IREE without setting up the runtime bindings to use it within Python. A common
    use case for this is compiling to Android (and similar targets).

    Args:
      function: The function to compile.
      static_argnums: Indices of positional arguments to treat as static when tracing.
      args: The inputs to trace and compile the function for.
      **options: Keyword args corresponding to xla.ImportOptions or CompilerOptions.
    """
    xla_comp = jax.xla_computation(function, static_argnums=static_argnums)(*args)
    hlo_proto = xla_comp.as_serialized_hlo_module_proto()
    return iree.compiler.tools.xla.compile_str(hlo_proto, input_type=iree.compiler.InputType.XLA, **options)

_DYLIB_CONFIG, _VM_INSTANCE = None, None


def load_iree_binary(pipeline_bytes: bytes):
    """Loads a compiled pipeline using IREE."""
    global _DYLIB_CONFIG, _VM_INSTANCE
    if not _DYLIB_CONFIG:
        # lazy initialization: ensures the config is created in the context of the calling process
        _DYLIB_CONFIG = ireert.Config("local-task")
    if not _VM_INSTANCE:
        _VM_INSTANCE = ireert.VmInstance()
    vm_module = ireert.VmModule.from_flatbuffer(_VM_INSTANCE, pipeline_bytes)
    module_object = ireert.load_vm_module(vm_module, _DYLIB_CONFIG)
    return module_object["main"]  # the compiled function to be run

ii = np.array([[1, 2, 3, 4, 5, 6, 7, 8], [9, 10, 11, 12, 13, 14, 15, 16]]).astype(np.float32)
f = build_group_avg()
f_jitted = jax.jit(f)
jax_result = np.array(f_jitted(ii))
compiled = aot(f, [], ii, target_backends=["llvm-cpu"])
ireefunc = load_iree_binary(compiled)
iree_result = np.array(ireefunc(ii))
print(jax_result)
print(iree_result)
assert jnp.allclose(jax_result, iree_result)
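When narrowing this down, note that mapping _f over columns with jax.vmap(_f, in_axes=1, out_axes=1) computes the same thing as a single jnp.nanmean(x, axis=0, keepdims=True) over the whole array. The JAX-only sketch below checks that equivalence; the function name direct is hypothetical and introduced here for illustration. Passing direct to the script's aot helper in place of f would show whether the wrong IREE results depend on the vmap-batched form or appear with nanmean alone. That comparison is a suggested next step, not something the report tested.

```python
import jax
import jax.numpy as jnp
import numpy as np


def _f(x):
    # x is a single column of shape (2,) once vmap strips axis 1
    return jnp.nanmean(x, axis=0, keepdims=True)


# Batched form from the report: map _f over columns (axis 1) and put the
# batch dimension back at axis 1, giving an output of shape (1, 8).
batched = jax.vmap(_f, in_axes=1, out_axes=1)


# Hypothetical direct form: one reduction over rows, no vmap involved.
def direct(x):
    return jnp.nanmean(x, axis=0, keepdims=True)


ii = np.array([[1, 2, 3, 4, 5, 6, 7, 8],
               [9, 10, 11, 12, 13, 14, 15, 16]], dtype=np.float32)

out_batched = np.asarray(jax.jit(batched)(ii))
out_direct = np.asarray(jax.jit(direct)(ii))
print(out_batched.shape, out_direct.shape)  # (1, 8) (1, 8)
```

If both forms agree under JAX but only one of them fails after IREE compilation, that narrows the search to the corresponding lowering.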
What component(s) does this issue relate to?
Frontends, Python, Runtime
Version information
jax==0.3.17
jaxlib==0.3.15
-f https://github.com/iree-org/iree/releases/expanded_assets/candidate-20220902.254
--extra-index-url https://github.com/iree-org/iree/releases/expanded_assets/candidate-20220902.254
iree-compiler
iree-runtime
iree-tools-tf
iree-tools-tflite
iree-tools-xla
Additional context
Ran on Ubuntu on an x86_64 CPU.
About this issue
- State: closed
- Created 2 years ago
- Comments: 16 (10 by maintainers)
Thanks. We are running behind on completing some of the processes for the new release. For now, please feel free to use it. We hope to get a new stable release out soon, but we are hitting some procedural issues.