jax: Crash in NCCL when running with distributed GPU
Description
(This could be a cluster issue, but it's unclear. This is on an HPC cluster where I don't have sudo, and there are a lot of CUDA installations floating around, so I haven't ruled out some dumb mismatch there, but I've done my best to track that down.)
This code works fine on (1) a single node with multiple GPUs, or (2) a pair of machines with a single GPU each.
For multi-host, multi-GPU it hangs regardless of whether I run one GPU per node or one GPU per process.
import os
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import multihost_utils
from jax.experimental.pjit import pjit
from jax.experimental import maps
from jax.experimental import PartitionSpec
# this is in slurm fwiw
jax.distributed.initialize(local_device_ids=[int(x) for x in os.environ.get('CUDA_VISIBLE_DEVICES', '0').split(',')])
# if slurm and one gpu per process then do
# jax.distributed.initialize()
print(jax.devices())
print(jax.process_count())
xs = jax.numpy.ones(jax.local_device_count())
r = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(xs)
print(r)
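(For reference, not part of the original report: a minimal sketch of initializing the distributed runtime explicitly from Slurm environment variables instead of relying on auto-detection, in case the auto-detected coordinator is suspect. COORDINATOR_ADDR is a hypothetical variable you would export yourself to point at the first node of the allocation, and port 1234 is arbitrary; SLURM_NTASKS and SLURM_PROCID are standard Slurm variables.)
import os
import jax

# Hedged sketch: explicit initialization, bypassing Slurm auto-detection.
jax.distributed.initialize(
    coordinator_address=f"{os.environ['COORDINATOR_ADDR']}:1234",  # hypothetical env var + arbitrary port
    num_processes=int(os.environ["SLURM_NTASKS"]),
    process_id=int(os.environ["SLURM_PROCID"]),
    local_device_ids=[int(x) for x in os.environ.get("CUDA_VISIBLE_DEVICES", "0").split(",")],
)
print(jax.devices())        # should list every GPU across all hosts
print(jax.process_count())  # should equal SLURM_NTASKS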
With one process per GPU and NCCL_DEBUG=TRACE I get:
[StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0), StreamExecutorGpuDevice(id=1, process_index=1, slice_index=0), StreamExecutorGpuDevice(id=2, process_index=2, slice_index=1), StreamExecutorGpuDevice(id=3, process_index=3, slice_index=1)]
4
sphinx6:70775:70775 [0] NCCL INFO Bootstrap : Using enp49s0f0:172.24.67.98<0>
sphinx6:70775:70775 [0] NCCL INFO NET/Plugin : No plugin found (libnccl-net.so), using internal implementation
sphinx6:70775:70775 [0] NCCL INFO cudaDriverVersion 11070
NCCL version 2.13.4+cudaCUDA_MAJOR.CUDA_MINOR
sphinx6:70775:70775 [0] NCCL INFO NET/IB : Using [0]mlx5_0:1/IB [1]mlx5_1:1/IB [2]mlx5_2:1/IB [3]mlx5_3:1/IB [RO]; OOB enp49s0f0:172.24.67.98<0>
sphinx6:70775:70775 [0] NCCL INFO Using network IB
sphinx6:70775:70775 [0] NCCL INFO Setting affinity for GPU 0 to ff00,00000000,00000000,00000000,0000ff00
sphinx6:70775:70775 [0] NCCL INFO Channel 00/02 : 0 1 2 3
sphinx6:70775:70775 [0] NCCL INFO Channel 01/02 : 0 1 2 3
sphinx6:70775:70775 [0] NCCL INFO Trees [0] 1/2/-1->0->-1 [1] 1/-1/-1->0->2
sphinx6:70775:70775 [0] NCCL INFO Channel 00/0 : 3[87000] -> 0[a000] [receive] via NET/IB/0/GDRDMA
sphinx6:70775:70775 [0] NCCL INFO Channel 01/0 : 3[87000] -> 0[a000] [receive] via NET/IB/0/GDRDMA
sphinx6:70775:70775 [0] NCCL INFO Channel 00 : 0[a000] -> 1[47000] via P2P/IPC/read
sphinx6:70775:70775 [0] NCCL INFO Channel 01 : 0[a000] -> 1[47000] via P2P/IPC/read
sphinx6:70775:70775 [0] NCCL INFO Connected all rings
sphinx6:70775:70775 [0] NCCL INFO Channel 00/0 : 2[47000] -> 0[a000] [receive] via NET/IB/0/GDRDMA
sphinx6:70775:70775 [0] NCCL INFO Channel 01/0 : 2[47000] -> 0[a000] [receive] via NET/IB/0/GDRDMA
sphinx6:70775:70775 [0] NCCL INFO Channel 00/0 : 0[a000] -> 2[47000] [send] via NET/IB/0/GDRDMA
sphinx6:70775:70775 [0] NCCL INFO Channel 01/0 : 0[a000] -> 2[47000] [send] via NET/IB/0/GDRDMA
sphinx6:70775:71035 [0] external/nccl_archive/src/misc/ibvwrap.cc:283 NCCL WARN Call to ibv_reg_mr_iova2 failed with error Cannot allocate memory
sphinx6:70775:71035 [0] NCCL INFO external/nccl_archive/src/transport/net_ib.cc:871 -> 2
sphinx6:70775:71035 [0] NCCL INFO bazel-out/k8-opt/bin/external/nccl_archive/_virtual_includes/include_hdrs/net.h:28 -> 2
sphinx6:70775:71035 [0] NCCL INFO external/nccl_archive/src/transport/net.cc:717 -> 2
sphinx6:70775:71035 [0] NCCL INFO external/nccl_archive/src/proxy.cc:968 -> 2
sphinx6:70775:71035 [0] external/nccl_archive/src/proxy.cc:1097 NCCL WARN [Proxy Service 0] Failed to execute operation Connect from rank 0, retcode 2
sphinx6:70775:70775 [0] external/nccl_archive/src/misc/socket.cc:521 NCCL WARN Net : Connection closed by remote peer sphinx6.stanford.edu<50233>
sphinx6:70775:70775 [0] NCCL INFO external/nccl_archive/src/misc/socket.cc:529 -> 6
sphinx6:70775:70775 [0] NCCL INFO external/nccl_archive/src/misc/socket.cc:541 -> 6
sphinx6:70775:70775 [0] NCCL INFO external/nccl_archive/src/proxy.cc:859 -> 6
sphinx6:70775:70775 [0] external/nccl_archive/src/proxy.cc:862 NCCL WARN Proxy Call to rank 0 failed (Connect)
sphinx6:70775:70775 [0] NCCL INFO external/nccl_archive/src/transport/net.cc:319 -> 6
sphinx6:70775:70775 [0] NCCL INFO external/nccl_archive/src/transport.cc:134 -> 6
sphinx6:70775:70775 [0] NCCL INFO external/nccl_archive/src/init.cc:784 -> 6
sphinx6:70775:70775 [0] NCCL INFO external/nccl_archive/src/init.cc:1045 -> 6
sphinx6:70775:70775 [0] NCCL INFO external/nccl_archive/src/init.cc:1091 -> 6
sphinx6:70775:70775 [0] NCCL INFO external/nccl_archive/src/init.cc:1106 -> 6
2022-12-07 22:00:53.711817: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2153] Execution of replica 0 failed: INTERNAL: external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nccl_utils.cc:285: NCCL operation ncclCommInitRank(comm.get(), nranks, id, rank) failed: remote process exited or there was a network error
Traceback (most recent call last):
File "/juice2/u/dlwh/src/levanter/simple_distributed.py", line 19, in <module>
r = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(xs)
File "/nlp/scr/dlwh/miniconda3/envs/levanter/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/nlp/scr/dlwh/miniconda3/envs/levanter/lib/python3.10/site-packages/jax/_src/api.py", line 2256, in cache_miss
out = map_bind_continuation(execute(*tracers))
File "/nlp/scr/dlwh/miniconda3/envs/levanter/lib/python3.10/site-packages/jax/_src/profiler.py", line 314, in wrapper
return func(*args, **kwargs)
File "/nlp/scr/dlwh/miniconda3/envs/levanter/lib/python3.10/site-packages/jax/interpreters/pxla.py", line 2071, in __call__
out_bufs = self.xla_executable.execute_sharded_on_local_devices(
jax._src.traceback_util.UnfilteredStackTrace: jaxlib.xla_extension.XlaRuntimeError: INTERNAL: external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nccl_utils.cc:285: NCCL operation ncclCommInitRank(comm.get(), nranks, id, rank) failed: remote process exited or there was a network error
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/juice2/u/dlwh/src/levanter/simple_distributed.py", line 19, in <module>
r = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(xs)
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nccl_utils.cc:285: NCCL operation ncclCommInitRank(comm.get(), nranks, id, rank) failed: remote process exited or there was a network error
What jax/jaxlib version are you using?
0.3.25
Which accelerator(s) are you using?
GPU (A100s)
Additional system info
Python 3.10.8, Ubuntu, CUDA 11.7
NVIDIA GPU info
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 515.43.04 Driver Version: 515.43.04 CUDA Version: 11.7 |
|-------------------------------+----------------------+----------------------+
| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|===============================+======================+======================|
| 0 NVIDIA A100-SXM... On | 00000000:0A:00.0 Off | 0 |
| N/A 28C P0 59W / 350W | 0MiB / 81920MiB | 0% Default |
| | | Disabled |
+-------------------------------+----------------------+----------------------+
| 1 NVIDIA A100-SXM... On | 00000000:47:00.0 Off | 0 |
| N/A 28C P0 59W / 350W | 0MiB / 81920MiB | 0% Default |
| | | Disabled |
+-------------------------------+----------------------+----------------------+
+-----------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=============================================================================|
| No running processes found |
+-----------------------------------------------------------------------------+
which nvcc gives /usr/local/cuda-11.7/bin/nvcc
About this issue
- Original URL
- State: closed
- Created 2 years ago
- Reactions: 1
- Comments: 26 (24 by maintainers)
Pretty sure this is an NCCL or system configuration problem now. If I set
NCCL_IB_DISABLE=1 NCCL_SOCKET_IFNAME=enp
the script works. Without both it doesn't: without NCCL_IB_DISABLE=1 I get warnings and then a hang, and without NCCL_SOCKET_IFNAME=enp it just hangs.
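For completeness, a minimal sketch (my assumption: these variables just need to be set before NCCL is first used, so setting them at the top of the script should be equivalent to exporting them in the shell) of applying that workaround from Python:
import os
# Workaround described above: disable the InfiniBand transport and pin NCCL to
# the ethernet interfaces whose names start with "enp" (seen in the bootstrap log).
os.environ["NCCL_IB_DISABLE"] = "1"
os.environ["NCCL_SOCKET_IFNAME"] = "enp"

import jax
jax.distributed.initialize(
    local_device_ids=[int(x) for x in os.environ.get("CUDA_VISIBLE_DEVICES", "0").split(",")]
)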
Of interest:
With NCCL_IB_DISABLE=1 NCCL_SOCKET_IFNAME=enp I see:
With just NCCL_SOCKET_IFNAME=enp I see:
With NCCL_IB_DISABLE=1 I see: