jax: Regression: CustomCall with side effect gets optimized out

Since JAX 0.3.15, our MPI CustomCalls are being optimized out of jitted functions, even though they are declared as having side effects (has_side_effect=True).

Example:

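```python
import os
os.environ["MPI4JAX_DEBUG"] = "1"  # make mpi4jax print a line for every MPI call

import jax
import jax.numpy as jnp

from mpi4py import MPI
from mpi4jax import send

comm = MPI.COMM_WORLD
assert comm.Get_size() == 1

@jax.jit
def send_jit(x):
    # Nothing is returned, so the jitted function has no outputs;
    # the send is executed purely for its side effect.
    send(x, dest=0, comm=comm)


print("send without JIT")
send(jnp.ones(10), dest=0, comm=comm)

print("send with JIT")
send_jit(jnp.ones(10))
```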

With jax==0.3.15, the second call prints no debug information, indicating that send is being optimized out:

```
send without JIT
r0 | ie1Jnxae | MPI_Send -> 0 with tag 0 and 10 items
r0 | ie1Jnxae | MPI_Send done with code 0 (4.43e-05s)
send with JIT
```

With jax==0.3.14 I get the expected output:

```
send without JIT
r0 | NCAn8Kbh | MPI_Send -> 0 with tag 0 and 10 items
r0 | NCAn8Kbh | MPI_Send done with code 0 (3.68e-05s)
send with JIT
r0 | Cn7a5QRb | MPI_Send -> 0 with tag 0 and 10 items
r0 | Cn7a5QRb | MPI_Send done with code 0 (5.04e-06s)
```

About this issue

  • State: closed
  • Created 2 years ago
  • Comments: 17 (9 by maintainers)


Most upvoted comments

@sharadmv I think the bug is here: https://github.com/google/jax/blob/c08b4ee6d9a13787c16bcf867d04e04d81fca063/jax/_src/dispatch.py#L329

Note that jaxpr.outvars is empty, and a custom call with a side effect isn't considered to carry a jaxpr effect. I suspect effectful custom calls need to be treated as effectful at the jaxpr level as well.
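The reasoning above can be illustrated with a toy backward dead-code-elimination pass. This is a minimal sketch under the assumption that the pruning behaves like a standard DCE over a flat equation list; `Eqn` and `dce` are hypothetical names, not JAX internals. An equation survives only if one of its outputs is live or it carries a jaxpr-level effect, so the XLA-level `has_side_effect` flag alone does not keep a send alive in a function with no outputs.

```python
from dataclasses import dataclass
from typing import FrozenSet, List


@dataclass
class Eqn:
    """Hypothetical stand-in for one jaxpr equation (not a JAX class)."""
    primitive: str
    invars: List[str]
    outvars: List[str]
    # Jaxpr-level effects: the only thing this toy DCE pass checks.
    effects: FrozenSet[str] = frozenset()
    # XLA custom-call flag: invisible to the pass below.
    has_side_effect: bool = False


def dce(eqns: List[Eqn], outvars: List[str]) -> List[Eqn]:
    """Toy backward dead-code elimination over a flat equation list.

    An equation is kept if it produces a live variable or carries a
    jaxpr-level effect; its inputs then become live in turn.
    """
    live = set(outvars)
    kept = []
    for eqn in reversed(eqns):
        if eqn.effects or any(v in live for v in eqn.outvars):
            kept.append(eqn)
            live.update(eqn.invars)
    kept.reverse()
    return kept


# send_jit returns nothing, so its (toy) jaxpr has no outputs.
send_eqn = Eqn("mpi_send", invars=["x"], outvars=["token"], has_side_effect=True)
print([e.primitive for e in dce([send_eqn], outvars=[])])  # [] : the send is dropped

# If the custom call also carried a jaxpr effect, DCE would keep it.
send_eqn_effectful = Eqn("mpi_send", invars=["x"], outvars=["token"],
                         effects=frozenset({"mpi"}), has_side_effect=True)
print([e.primitive for e in dce([send_eqn_effectful], outvars=[])])  # ['mpi_send']
```

Attaching a jaxpr effect to the custom call is the change suggested in the comment; in this sketch it is the only way an output-less send survives, because the pass never looks at `has_side_effect`.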