netket: MPI_Allreduce returned error code 15: b'MPI_ERR_TRUNCATE: message truncated'

Hello, I’m aware this might be a shot in the dark, but let’s see where it goes: I have a script that trains a custom model on the toric code. This script works and runs without any problems under MPI with one or two processes. However, when running it with “-np 4” I get the following error:

MPI_Allreduce returned error code 15: b'MPI_ERR_TRUNCATE: message truncated'

I’ve been able to narrow down the source. After reducing my model’s complexity the error disappeared and I was able to run it on e.g. 4 processes. Moreover, vqs.expect(…) calls don’t trigger the error, but vqs.expect_and_forces(…) does. So it seems that the last line in expect_forces.py,

return Ō, jax.tree_map(lambda x: mpi.mpi_sum_jax(x)[0], Ō_grad), new_model_state

contains the MPI Allreduce call that throws the error. This is plausible to me because the error message suggests buffer issues, and increasing the number of parameters in my network directly affects the size of Ō_grad, which gets reduced.
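For reference, here is a condensed sketch of what that quoted line does, based on the line above; Ō_grad stands for the gradient pytree and mpi.mpi_sum_jax is NetKet's wrapper around the MPI sum:

```python
import jax
from netket.utils import mpi

# Sum every leaf of the gradient pytree across all MPI ranks.
# mpi_sum_jax returns (reduced_value, token); [0] keeps the reduced value.
# The total data volume reduced here scales with the number of model
# parameters, which matches the observation that a larger model
# triggers the truncation error while a smaller one does not.
Ō_grad = jax.tree_map(lambda x: mpi.mpi_sum_jax(x)[0], Ō_grad)
```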

Now, is this something I can fix on my end, or does it need to be fixed within NetKet, mpi4jax, or my MPI installation? I’m not sure where to go from here.

Thanks!

Edit: I’m using the following versions: NetKet 3.7, mpi4py 3.1.4, mpi4jax 0.3.14.post1, jax 0.4.6, jaxlib 0.4.6+cuda11.cudnn82.

About this issue

  • State: open
  • Created a year ago
  • Comments: 19 (9 by maintainers)

Most upvoted comments

@MandMarc, @inailuig identified the problem and is working on a fix. Indeed, it was due to token mismanagement in NetKet.

The fix will take a bit to land, but we’ll get there…
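For context (my own reading of that remark, not the maintainers' actual fix): mpi4jax 0.3.x keeps collective calls ordered across ranks by threading a token through successive operations. A minimal sketch of the intended pattern, assuming mpi4jax's token-based allreduce API; the arrays here are placeholders:

```python
import jax.numpy as jnp
from mpi4py import MPI
import mpi4jax

comm = MPI.COMM_WORLD

x = jnp.ones(3)
y = jnp.ones(5)

# Thread the token through consecutive reductions so every rank issues
# the Allreduce calls in the same order. If the token chain is broken
# ("token mismanagement"), ranks can pair up mismatched messages of
# different sizes, which surfaces as MPI_ERR_TRUNCATE.
xsum, token = mpi4jax.allreduce(x, op=MPI.SUM, comm=comm)
ysum, token = mpi4jax.allreduce(y, op=MPI.SUM, comm=comm, token=token)
```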