mpi4jax: Incorrect AlltoAll behavior in jitted function
First of all, thanks for this really great project! I am very much looking forward to using it to build a jax-based distributed cosmological N-body simulation (https://github.com/DifferentiableUniverseInitiative/JaxPM). However, I’m encountering a very strange issue when trying to JIT functions, which I can boil down to the minimal working example below, involving an AlltoAll collective used to transpose a distributed array:
from mpi4py import MPI
import jax
import jax.numpy as jnp
import mpi4jax
# Create communicators
world = MPI.COMM_WORLD
rank = world.Get_rank()
size = world.Get_size()
# Global size of the 2D array
shape = [256, 256]
# Create a slice of the 2D array on each process
master_key = jax.random.PRNGKey(42)
key = jax.random.split(master_key, size)[rank]
local_slice = jax.random.normal(key, [shape[0]//size, shape[1]])
@jax.jit
def transpose(arr):
    # Split the last dimension into as many chunks as there are processes
    arr = arr.reshape([shape[0]//size, size, shape[1]//size])
    arr = arr.transpose([1, 0, 2])  # Move the chunk index to the first dimension
    arr, token = mpi4jax.alltoall(arr, comm=world)
    # Merge the collected chunks along the first dimension
    arr = arr.reshape([shape[0], shape[1]//size])
    arr = arr.transpose([1, 0])  # Move the unsliced dimension to the end
    return arr, token
# Perform the transpose
transposed_slice, token = transpose(local_slice)
# Let's make sure we have waited for all computations to be done
token = mpi4jax.barrier(token=token)
# Let's check that the results are OK
transposed_array, token = mpi4jax.allgather(transposed_slice, comm=world, token=token)
transposed_array = transposed_array.reshape(shape)
total_array, token = mpi4jax.allgather(local_slice, comm=world, token=token)
total_array = total_array.reshape(shape)
print("On rank %d, here is the result: "%rank,
jnp.abs(total_array.transpose([1,0]) - transposed_array).mean())
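For reference, a quick sanity check on shapes can be dropped in right after the transpose call (illustrative only; it just restates the expected per-rank slice shape):
# Each rank should end up holding a (shape[1] // size, shape[0]) slice of the
# globally transposed array, i.e. (64, 256) when running on 4 processes.
assert transposed_slice.shape == (shape[1] // size, shape[0])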
If the transpose function is jitted, the output on a 4-process run is:
$ srun python test_mpi4jax.py
On rank 0, here is the result: 1.1207631
On rank 1, here is the result: 1.1207631
On rank 3, here is the result: 1.1207631
On rank 2, here is the result: 1.1207631
If I remove the jitting of that function, I get the expected behavior, showing that the transpose is computed correctly:
$ srun python test_mpi4jax.py
On rank 0, here is the result: 0.0
On rank 1, here is the result: 0.0
On rank 2, here is the result: 0.0
On rank 3, here is the result: 0.0
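For what it’s worth, the un-jitted behavior can also be reproduced without removing the decorator by wrapping the call in jax.disable_jit(), a standard JAX context manager; the snippet below is just an illustration:
# Temporarily disable jit to obtain the reference (correct) result.
with jax.disable_jit():
    reference_slice, _ = transpose(local_slice)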
This example was run with jax v0.3.23 and mpi4jax v0.3.10.post6; I also tried reverting to a jax version from this summer, with no difference. I have only tried on GPU, but the behavior appears to be the same whether or not CUDA-aware MPI is used. With debug mode enabled, I can see that the communications do occur as expected.
Very interestingly, I noticed that if I tweak the code in the following way:
@jax.jit
def transpose(arr):
    arr = arr.reshape([shape[0]//size, size, shape[1]//size])
    arr = jnp.einsum('ij,xjy->ixy', jnp.eye(size), arr)  # instead of arr.transpose([1, 0, 2])
    arr, token = mpi4jax.alltoall(arr, comm=world)
    return arr, token
arr, token = transpose(local_slice)
arr = arr.reshape([shape[0], shape[1]//size])
transposed_slice = arr.transpose([1,0])
i.e. replacing the first jnp.transpose with an equivalent but more complex op involving an einsum gives me the correct results. So I don’t think there is anything intrinsically wrong with mpi4jax itself; rather, something weird is going on with the input and output buffers of the collective op for particular lax ops.
This makes me think that the physical memory layout of the buffer at the output of the transpose op doesn’t actually correspond to its theoretical shape, which may be an optimization done by XLA to avoid shuffling the memory layout back and forth between operations.
I’m not sure how to debug this further; any help or insight would be much appreciated! I’ll continue to poke around anyway, because I really really want to use mpi4jax for my application ^^
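One thing I might try next is dumping what XLA actually compiles for the jitted function, to see whether the transpose is materialized at all. This is only a sketch; the ahead-of-time inspection API (lower / compile / as_text) may differ slightly between JAX versions:
# Dump the IR for the jitted transpose (it is already wrapped in jax.jit).
lowered = transpose.lower(local_slice)
print(lowered.as_text())            # IR before XLA optimizations
print(lowered.compile().as_text())  # optimized HLO, including buffer layouts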
About this issue
- State: closed
- Created 2 years ago
- Comments: 32 (6 by maintainers)
Commits related to this issue
- add test for issue #176 — committed to mpi4jax/mpi4jax by dionhaefner 2 years ago
I created an upstream issue, let’s see what comes out of this.
Works like a charm! Awesome, thanks so much! Tested on the example at the top of this thread using JAX v0.3.25 on 8 GPUs across 2 nodes.
@EiffL or @PhilipVinc could you confirm that this is fixed on main?
Sounds like a fix is going to go live soon. As soon as it lands I will probably take the opportunity and migrate everything to custom_call.
Indeed, you’re right that the slicing in alltoall cares about the ordering… But then you don’t need to apply your decorator to all the primitives we export.
I’m unsure about the ordering: I just tried, and you can access an object in the encoding phase that gives you the layout of the input, but the problem is that I can’t use that to create the output layout, as … fails.
Correct me if I’m wrong, but I think there is no exposed way to manipulate or even retrieve buffer layouts (these are determined at the XLA level, for example when XLA fuses subsequent transposes). Hence the need for layout parameters in custom call wrappers that inform XLA what kind of layout we expect on inputs and outputs. I don’t know if there’s an input that says “I don’t care about the layout as long as input and output layouts are the same”.
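As an aside, with the old XLA builder API those layouts are specified by hand, roughly like this (an illustrative sketch of the deprecated custom-call interface, not mpi4jax’s actual code):
import numpy as np
from jaxlib import xla_client

def row_major_shape(dtype, dims):
    # A Shape carrying an explicit minor-to-major order;
    # (ndim - 1, ..., 0) means row-major / C-contiguous.
    return xla_client.Shape.array_shape(
        np.dtype(dtype), dims, tuple(range(len(dims) - 1, -1, -1))
    )
Such shapes would then be passed as operand_shapes_with_layout / shape_with_layout to xla_client.ops.CustomCallWithLayout(...), which is how a custom call tells XLA which input and output layouts it expects.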
Thanks for looking into this. Your code looks good to me; this is what I would have tried, too. I presume this is a bug in JAX / XLA - we encounter these regularly, since we seem to be the only major users of the token mechanism.
Do you have a good standalone example that shows the problem when using tokens with CustomCallWithLayout? In that case we should raise the issue. I wouldn’t bet on it getting fixed, because the custom call interface is deprecated, but if it’s an easy fix they might do it.
Sounds like the long-term solution is to switch to the new interface. In the meantime we could add something like arr = jnp.ascontiguousarray(arr) to all handlers. Could you try whether that fixes it?
Layout is not checked; everything is assumed to be C contiguous, I think. So that’s probably the source of the error.
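In the meantime, a user-side stopgap in the spirit of the einsum trick from the original report might look like the hypothetical helper below (not part of mpi4jax; whether it actually helps depends on what XLA does with the extra op):
import jax.numpy as jnp
import mpi4jax

def alltoall_contiguous(arr, comm, token=None):
    # Hypothetical workaround: contract the leading axis with an identity matrix
    # so that XLA materializes a fresh buffer before the collective, mirroring
    # the einsum variant of `transpose` above.
    eye = jnp.eye(arr.shape[0], dtype=arr.dtype)
    arr = jnp.einsum('ij,j...->i...', eye, arr)
    return mpi4jax.alltoall(arr, comm=comm, token=token)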
It’s not that strange, because the compiler is much more aggressive on GPU. So it actually hints towards what you suspected: The transposes are not actually manifested on GPU. I’ll have to dig into this and see how the transpose information is passed down in this case.