jax: Regression: CustomCall with side effect gets optimized out
Since JAX 0.3.15, our MPI CustomCalls are getting optimized out even though they are declared as having side effects (has_side_effect=True).
Example:
import os
os.environ["MPI4JAX_DEBUG"] = "1"
import jax
import jax.numpy as jnp
from mpi4py import MPI
from mpi4jax import send
comm = MPI.COMM_WORLD
assert comm.Get_size() == 1
@jax.jit
def send_jit(x):
    send(x, dest=0, comm=comm)
print("send without JIT")
send(jnp.ones(10), dest=0, comm=comm)
print("send with JIT")
send_jit(jnp.ones(10))
With jax==0.3.15, the second call prints no debug information, indicating that send is being optimized out:
send without JIT
r0 | ie1Jnxae | MPI_Send -> 0 with tag 0 and 10 items
r0 | ie1Jnxae | MPI_Send done with code 0 (4.43e-05s)
send with JIT
With jax==0.3.14, I get the expected output:
send without JIT
r0 | NCAn8Kbh | MPI_Send -> 0 with tag 0 and 10 items
r0 | NCAn8Kbh | MPI_Send done with code 0 (3.68e-05s)
send with JIT
r0 | Cn7a5QRb | MPI_Send -> 0 with tag 0 and 10 items
r0 | Cn7a5QRb | MPI_Send done with code 0 (5.04e-06s)
About this issue
- State: closed
- Created 2 years ago
- Comments: 17 (9 by maintainers)
@sharadmv I think the bug is here: https://github.com/google/jax/blob/c08b4ee6d9a13787c16bcf867d04e04d81fca063/jax/_src/dispatch.py#L329
Note that jaxpr.outvars is empty, and a custom call with a side effect isn’t considered to have a jaxpr effect. I suspect we need to treat effectful custom calls as effectful for jaxpr purposes as well?
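
To make the suspected mechanism concrete, here is a toy model of dead-code elimination in plain Python. It is only a sketch and not JAX's actual implementation: the names Eqn and dce, and the "mpi" effect label, are hypothetical. It assumes a pass that keeps an equation only if one of its outputs is still needed or if the equation carries an effect. With no program outputs (as in send_jit above) and no effect attached, the send is dropped; with an effect attached, it survives.

```python
from dataclasses import dataclass


@dataclass
class Eqn:
    """Hypothetical stand-in for one equation in a traced program."""
    name: str
    invars: list
    outvars: list
    effects: frozenset = frozenset()


def dce(eqns, outvars):
    """Toy dead-code elimination: walk backwards and keep an equation
    only if it produces a live variable or carries an effect."""
    live = set(outvars)
    kept = []
    for eqn in reversed(eqns):
        if eqn.effects or any(v in live for v in eqn.outvars):
            live.update(eqn.invars)  # inputs of a kept equation become live
            kept.append(eqn)
    return kept[::-1]


# A program shaped like send_jit: one custom call, no returned values.
unmarked = [Eqn("mpi_send", ["x"], ["token"])]
marked = [Eqn("mpi_send", ["x"], ["token"], effects=frozenset({"mpi"}))]

print([e.name for e in dce(unmarked, [])])  # send is eliminated
print([e.name for e in dce(marked, [])])    # send is kept
```

In this model, the side-effect flag on the lowered custom call never enters the picture, because the decision is made one level up on the equation's own effects. That matches the hypothesis in the comment above, but whether the real code path behaves this way is for the maintainers to confirm.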