mpi4jax: Segmentation fault on GPU to GPU communication

I get a segmentation fault with some MPI primitives when using CUDA-enabled MPI. The issue seems to appear when XLA has not been initialized yet: the error disappears if memory is allocated on the GPU before mpi4jax is imported.

Run command used (with GPUs on two separate nodes):

MPI4JAX_USE_CUDA_MPI=1 mpiexec -npernode 1 python run.py

Contents of run.py to reproduce the error:

import mpi4jax
from mpi4py import MPI

comm = MPI.COMM_WORLD
comm_size = comm.Get_size()
rank = comm.Get_rank()

root_rank, _ = mpi4jax.bcast(rank, root=0, comm=comm)
print(rank, root_rank)

Error message:

--------------------------------------------------------------------------
Primary job  terminated normally, but 1 process returned
a non-zero exit code. Per user-direction, the job has been aborted.
--------------------------------------------------------------------------
--------------------------------------------------------------------------
mpiexec noticed that process rank 0 with PID 0 on node 0 exited on signal 11 (Segmentation fault).
--------------------------------------------------------------------------

Workarounds

1: Using MPI4JAX_USE_CUDA_MPI=0.

2: Importing jax and creating a DeviceArray fixes the problem, but this has to happen before import mpi4jax. For example, inserting the following at the beginning of run.py works:

import jax.numpy as jnp
jnp.array(3.)
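Because workaround 2 only helps when the allocation happens before mpi4jax is imported, one option is to wrap the warm-up in a small helper that refuses to run once mpi4jax is already loaded. The module name xla_warmup.py and the function warm_up_xla are hypothetical and not part of mpi4jax. This is a sketch that only reproduces the reported workaround; it does not fix the underlying segfault.

```python
# xla_warmup.py (hypothetical helper, not part of mpi4jax)
import sys

import jax.numpy as jnp


def warm_up_xla():
    """Force XLA to initialize by allocating a small array on the default device.

    Assumption (from this issue): this must run before `import mpi4jax`.
    """
    if "mpi4jax" in sys.modules:
        # The warm-up is reported to help only before mpi4jax is imported,
        # so fail loudly instead of silently doing nothing.
        raise RuntimeError("warm_up_xla() must be called before importing mpi4jax")
    x = jnp.array(3.0)  # same allocation as in workaround 2
    x.block_until_ready()  # wait until the buffer actually exists on the device
    return x
```

In run.py, import this module and call warm_up_xla() as the very first statements, and only then import mpi4jax and mpi4py.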

3: Some primitives (I only tested mpi4jax.allreduce) work fine out of the box. The following code does not crash when the allreduce is placed before the bcast:

rank_sum, _ = mpi4jax.allreduce(rank, op=MPI.SUM, comm=comm)
print(rank, rank_sum)

root_rank, _ = mpi4jax.bcast(rank, root=0, comm=comm)
print(rank, root_rank)

Versions

  • Python 3.8.6
  • OpenMPI 4.0.5-gcccuda-2020b
  • CUDA 11.1.1.GCC-10.2.0
  • mpi4py 3.1.1
  • mpi4jax 0.3.2
  • jax 0.2.21
  • jaxlib 0.1.71[cuda111]

About this issue

  • State: open
  • Created 3 years ago
  • Comments: 27 (8 by maintainers)

Most upvoted comments

Thanks for the help, though, and for the nice library!