jax: Crash in NCCL when running with distributed GPU

Description

(This could be a cluster issue, but it's unclear. This is on an HPC cluster where I don't have sudo, and there are a lot of CUDA installations floating around, so I haven't ruled out some dumb mismatch there, but I've done my best to track that down.)
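Since several CUDA installations are in play, one quick check is which of them the job environment actually points at. Below is a minimal sketch built on a hypothetical helper. `cuda_entries` is not part of any library. It scans `PATH`, `LD_LIBRARY_PATH` and the commonly used `CUDA_HOME` / `CUDA_PATH` variables for entries that mention "cuda". This is only a text heuristic: it says nothing about which libraries jaxlib actually loads. It is best run from inside the Slurm job, because the job environment may differ from the login shell.

```python
import os


def cuda_entries(env=None, keys=("PATH", "LD_LIBRARY_PATH", "CUDA_HOME", "CUDA_PATH")):
    """Return, per variable, the path entries whose text mentions 'cuda'.

    Hypothetical diagnostic helper: several different CUDA versions showing
    up here would be a hint of a mismatch, nothing more.
    """
    env = os.environ if env is None else env
    found = {}
    for key in keys:
        entries = [p for p in env.get(key, "").split(os.pathsep) if p]
        hits = [p for p in entries if "cuda" in p.lower()]
        if hits:
            found[key] = hits
    return found


if __name__ == "__main__":
    for key, paths in cuda_entries().items():
        print(key)
        for p in paths:
            print("   ", p)
```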

This code works fine on (1) a single node with multiple GPUs, or (2) a pair of machines with a single GPU each.

In the multi-host, multi-GPU case it hangs, regardless of whether I use one GPU per node or one GPU per process.

import os

import jax
import jax.numpy as jnp

# This runs under Slurm. Restrict this process to the GPUs listed in
# CUDA_VISIBLE_DEVICES (defaulting to GPU 0 if the variable is unset).
local_ids = [int(x) for x in os.environ.get('CUDA_VISIBLE_DEVICES', '0').split(',')]
jax.distributed.initialize(local_device_ids=local_ids)
# With Slurm and one GPU per process, this is enough instead:
# jax.distributed.initialize()

print(jax.devices())        # global device list, across all processes
print(jax.process_count())  # number of participating processes

# One element per local device; psum over the pmapped axis 'i' sums across
# every device in every process.
xs = jnp.ones(jax.local_device_count())
r = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(xs)
print(r)
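For reference, this is what a successful run should print. In multi-process `pmap`, the mapped axis spans every device in `jax.devices()`, and each process passes in only its local slice. The `psum` of ones should therefore give every process an array whose entries all equal the global device count, which is 4.0 in the four-GPU runs here.

The sketch below is a pure-Python model of those semantics, not JAX itself. `simulate_pmap_psum` is a hypothetical name.

```python
def simulate_pmap_psum(local_inputs):
    """Model of jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')
    run across several processes (illustration only, not JAX code).

    local_inputs[p] holds process p's per-device values (its `xs`). psum over
    the pmapped axis reduces over every device of every process, so each
    device in each process receives the same global total.
    """
    total = sum(sum(values) for values in local_inputs)
    return [[total] * len(values) for values in local_inputs]


# Four processes with one GPU each, every xs equal to ones(1).
print(simulate_pmap_psum([[1.0], [1.0], [1.0], [1.0]]))  # [[4.0], [4.0], [4.0], [4.0]]
# Two nodes, one process per node driving two GPUs: xs = ones(2) per process.
print(simulate_pmap_psum([[1.0, 1.0], [1.0, 1.0]]))  # [[4.0, 4.0], [4.0, 4.0]]
```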

With one process per GPU and NCCL_DEBUG=TRACE, I get:

[StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0), StreamExecutorGpuDevice(id=1, process_index=1, slice_index=0), StreamExecutorGpuDevice(id=2, process_index=2, slice_index=1), StreamExecutorGpuDevice(id=3, process_index=3, slice_index=1)]
4
sphinx6:70775:70775 [0] NCCL INFO Bootstrap : Using enp49s0f0:172.24.67.98<0>
sphinx6:70775:70775 [0] NCCL INFO NET/Plugin : No plugin found (libnccl-net.so), using internal implementation
sphinx6:70775:70775 [0] NCCL INFO cudaDriverVersion 11070
NCCL version 2.13.4+cudaCUDA_MAJOR.CUDA_MINOR
sphinx6:70775:70775 [0] NCCL INFO NET/IB : Using [0]mlx5_0:1/IB [1]mlx5_1:1/IB [2]mlx5_2:1/IB [3]mlx5_3:1/IB [RO]; OOB enp49s0f0:172.24.67.98<0>
sphinx6:70775:70775 [0] NCCL INFO Using network IB
sphinx6:70775:70775 [0] NCCL INFO Setting affinity for GPU 0 to ff00,00000000,00000000,00000000,0000ff00
sphinx6:70775:70775 [0] NCCL INFO Channel 00/02 :    0   1   2   3
sphinx6:70775:70775 [0] NCCL INFO Channel 01/02 :    0   1   2   3
sphinx6:70775:70775 [0] NCCL INFO Trees [0] 1/2/-1->0->-1 [1] 1/-1/-1->0->2
sphinx6:70775:70775 [0] NCCL INFO Channel 00/0 : 3[87000] -> 0[a000] [receive] via NET/IB/0/GDRDMA
sphinx6:70775:70775 [0] NCCL INFO Channel 01/0 : 3[87000] -> 0[a000] [receive] via NET/IB/0/GDRDMA
sphinx6:70775:70775 [0] NCCL INFO Channel 00 : 0[a000] -> 1[47000] via P2P/IPC/read
sphinx6:70775:70775 [0] NCCL INFO Channel 01 : 0[a000] -> 1[47000] via P2P/IPC/read
sphinx6:70775:70775 [0] NCCL INFO Connected all rings
sphinx6:70775:70775 [0] NCCL INFO Channel 00/0 : 2[47000] -> 0[a000] [receive] via NET/IB/0/GDRDMA
sphinx6:70775:70775 [0] NCCL INFO Channel 01/0 : 2[47000] -> 0[a000] [receive] via NET/IB/0/GDRDMA
sphinx6:70775:70775 [0] NCCL INFO Channel 00/0 : 0[a000] -> 2[47000] [send] via NET/IB/0/GDRDMA
sphinx6:70775:70775 [0] NCCL INFO Channel 01/0 : 0[a000] -> 2[47000] [send] via NET/IB/0/GDRDMA

sphinx6:70775:71035 [0] external/nccl_archive/src/misc/ibvwrap.cc:283 NCCL WARN Call to ibv_reg_mr_iova2 failed with error Cannot allocate memory
sphinx6:70775:71035 [0] NCCL INFO external/nccl_archive/src/transport/net_ib.cc:871 -> 2
sphinx6:70775:71035 [0] NCCL INFO bazel-out/k8-opt/bin/external/nccl_archive/_virtual_includes/include_hdrs/net.h:28 -> 2
sphinx6:70775:71035 [0] NCCL INFO external/nccl_archive/src/transport/net.cc:717 -> 2
sphinx6:70775:71035 [0] NCCL INFO external/nccl_archive/src/proxy.cc:968 -> 2

sphinx6:70775:71035 [0] external/nccl_archive/src/proxy.cc:1097 NCCL WARN [Proxy Service 0] Failed to execute operation Connect from rank 0, retcode 2

sphinx6:70775:70775 [0] external/nccl_archive/src/misc/socket.cc:521 NCCL WARN Net : Connection closed by remote peer sphinx6.stanford.edu<50233>
sphinx6:70775:70775 [0] NCCL INFO external/nccl_archive/src/misc/socket.cc:529 -> 6
sphinx6:70775:70775 [0] NCCL INFO external/nccl_archive/src/misc/socket.cc:541 -> 6
sphinx6:70775:70775 [0] NCCL INFO external/nccl_archive/src/proxy.cc:859 -> 6

sphinx6:70775:70775 [0] external/nccl_archive/src/proxy.cc:862 NCCL WARN Proxy Call to rank 0 failed (Connect)
sphinx6:70775:70775 [0] NCCL INFO external/nccl_archive/src/transport/net.cc:319 -> 6
sphinx6:70775:70775 [0] NCCL INFO external/nccl_archive/src/transport.cc:134 -> 6
sphinx6:70775:70775 [0] NCCL INFO external/nccl_archive/src/init.cc:784 -> 6
sphinx6:70775:70775 [0] NCCL INFO external/nccl_archive/src/init.cc:1045 -> 6
sphinx6:70775:70775 [0] NCCL INFO external/nccl_archive/src/init.cc:1091 -> 6
sphinx6:70775:70775 [0] NCCL INFO external/nccl_archive/src/init.cc:1106 -> 6
2022-12-07 22:00:53.711817: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2153] Execution of replica 0 failed: INTERNAL: external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nccl_utils.cc:285: NCCL operation ncclCommInitRank(comm.get(), nranks, id, rank) failed: remote process exited or there was a network error
Traceback (most recent call last):
  File "/juice2/u/dlwh/src/levanter/simple_distributed.py", line 19, in <module>
    r = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(xs)
  File "/nlp/scr/dlwh/miniconda3/envs/levanter/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/nlp/scr/dlwh/miniconda3/envs/levanter/lib/python3.10/site-packages/jax/_src/api.py", line 2256, in cache_miss
    out = map_bind_continuation(execute(*tracers))
  File "/nlp/scr/dlwh/miniconda3/envs/levanter/lib/python3.10/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/nlp/scr/dlwh/miniconda3/envs/levanter/lib/python3.10/site-packages/jax/interpreters/pxla.py", line 2071, in __call__
    out_bufs = self.xla_executable.execute_sharded_on_local_devices(
jax._src.traceback_util.UnfilteredStackTrace: jaxlib.xla_extension.XlaRuntimeError: INTERNAL: external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nccl_utils.cc:285: NCCL operation ncclCommInitRank(comm.get(), nranks, id, rank) failed: remote process exited or there was a network error

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/juice2/u/dlwh/src/levanter/simple_distributed.py", line 19, in <module>
    r = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(xs)
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nccl_utils.cc:285: NCCL operation ncclCommInitRank(comm.get(), nranks, id, rank) failed: remote process exited or there was a network error

What jax/jaxlib version are you using?

0.3.25

Which accelerator(s) are you using?

GPU (A100s)

Additional system info

Python 3.10.8, Ubuntu, CUDA 11.7

NVIDIA GPU info

+-----------------------------------------------------------------------------+
| NVIDIA-SMI 515.43.04    Driver Version: 515.43.04    CUDA Version: 11.7     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA A100-SXM...  On   | 00000000:0A:00.0 Off |                    0 |
| N/A   28C    P0    59W / 350W |      0MiB / 81920MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA A100-SXM...  On   | 00000000:47:00.0 Off |                    0 |
| N/A   28C    P0    59W / 350W |      0MiB / 81920MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|  No running processes found                                                 |
+-----------------------------------------------------------------------------+

Running which nvcc gives /usr/local/cuda-11.7/bin/nvcc.
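As a cross-check on the version question, the NCCL trace above reports `cudaDriverVersion 11070`. CUDA encodes versions as `1000 * major + 10 * minor`, so that value decodes to 11.7. This agrees with nvidia-smi and with the nvcc path.

The sketch below just does that decoding; the helper names are hypothetical. It only covers the driver API version that NCCL reported. The `external/nccl_archive` paths in the log suggest that the NCCL in use was built into jaxlib rather than taken from a system install.

```python
import re


def decode_cuda_version(version):
    """Split CUDA's integer encoding (1000 * major + 10 * minor) into (major, minor)."""
    return version // 1000, (version % 1000) // 10


def driver_version_from_nccl_log(line):
    """Extract and decode the integer from an NCCL 'cudaDriverVersion NNNNN' INFO line."""
    match = re.search(r"cudaDriverVersion (\d+)", line)
    if match is None:
        raise ValueError("no cudaDriverVersion field in line")
    return decode_cuda_version(int(match.group(1)))


line = "sphinx6:70775:70775 [0] NCCL INFO cudaDriverVersion 11070"
print(driver_version_from_nccl_log(line))  # (11, 7)
```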

About this issue

  • State: closed
  • Created 2 years ago
  • Reactions: 1
  • Comments: 26 (24 by maintainers)

Most upvoted comments

Pretty sure this is an NCCL or system configuration problem now. If I set both NCCL_IB_DISABLE=1 and NCCL_SOCKET_IFNAME=enp, the script works. If either one is missing, it doesn't.
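The workaround can be exported in the Slurm batch script, or set from Python. The sketch below takes the Python route. `apply_nccl_workaround` is a hypothetical helper, and it assumes NCCL reads these variables when JAX first creates a communicator. Calling it at the very top of the script, before `import jax`, sidesteps any question of timing.

Two further points apply:
- `enp` is a prefix, not an interface name: NCCL matches any interface whose name starts with it.
- Disabling IB moves inter-node traffic onto TCP sockets, so lower inter-node bandwidth is to be expected.

```python
import os

# Settings reported in this thread to make the script work on this cluster.
WORKAROUND = {"NCCL_IB_DISABLE": "1", "NCCL_SOCKET_IFNAME": "enp"}


def apply_nccl_workaround(env=None):
    """Set the workaround variables unless the job already set them.

    Intended to run before JAX's first collective (simplest: before `import jax`).
    """
    env = os.environ if env is None else env
    for key, value in WORKAROUND.items():
        env.setdefault(key, value)  # an explicit setting from the job wins
    return env
```

Usage is `apply_nccl_workaround()` followed by `import jax` at the top of the script. This is meant to be equivalent to running `NCCL_IB_DISABLE=1 NCCL_SOCKET_IFNAME=enp python simple_distributed.py`.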

Without NCCL_IB_DISABLE=1, I get these warnings and then a hang:

sphinx1:2609339:2609476 [0] external/nccl_archive/src/misc/ibvwrap.cc:283 NCCL WARN Call to ibv_reg_mr_iova2 failed with error Cannot allocate memory
sphinx1:2609339:2609476 [0] NCCL INFO external/nccl_archive/src/transport/net_ib.cc:871 -> 2
sphinx1:2609339:2609476 [0] NCCL INFO bazel-out/k8-opt/bin/external/nccl_archive/_virtual_includes/include_hdrs/net.h:28 -> 2
sphinx1:2609339:2609476 [0] NCCL INFO external/nccl_archive/src/transport/net.cc:717 -> 2
sphinx1:2609339:2609476 [0] NCCL INFO external/nccl_archive/src/proxy.cc:968 -> 2
sphinx1:2609339:2609476 [0] NCCL INFO external/nccl_archive/src/proxy.cc:996 -> 2

sphinx1:2609339:2609476 [0] external/nccl_archive/src/proxy.cc:1097 NCCL WARN [Proxy Service 0] Failed to execute operation Connect from rank 0, retcode 2
sphinx1:2609339:2609476 [0] NCCL INFO external/nccl_archive/src/transport/net_ib.cc:690 -> 2
sphinx1:2609339:2609476 [0] NCCL INFO bazel-out/k8-opt/bin/external/nccl_archive/_virtual_includes/include_hdrs/net.h:27 -> 2
sphinx1:2609339:2609476 [0] NCCL INFO external/nccl_archive/src/transport/net.cc:643 -> 2
sphinx1:2609339:2609476 [0] NCCL INFO external/nccl_archive/src/proxy.cc:968 -> 2

sphinx1:2609339:2609476 [0] external/nccl_archive/src/proxy.cc:1097 NCCL WARN [Proxy Service 0] Failed to execute operation Connect from rank 0, retcode 2

sphinx1:2609339:2609443 [0] external/nccl_archive/src/misc/socket.cc:521 NCCL WARN Net : Connection closed by remote peer sphinx1.stanford.edu<53799>
sphinx1:2609339:2609443 [0] NCCL INFO external/nccl_archive/src/misc/socket.cc:529 -> 6
sphinx1:2609339:2609443 [0] NCCL INFO external/nccl_archive/src/misc/socket.cc:541 -> 6
sphinx1:2609339:2609443 [0] NCCL INFO external/nccl_archive/src/proxy.cc:859 -> 6
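"Cannot allocate memory" is the strerror text for ENOMEM, returned here by the InfiniBand memory registration call `ibv_reg_mr_iova2`. One frequent cause of ENOMEM during IB memory registration is a locked-memory limit (`RLIMIT_MEMLOCK`, shown by `ulimit -l`) that is too low inside the job. Whether that applies on this cluster is an assumption, not something the thread establishes.

The sketch below reports the limit from Python. It should be run under `srun` inside the allocation, because limits there can differ from the login node. `describe_limit` and `memlock_report` are hypothetical helpers. `resource.getrlimit` returns bytes, while `ulimit -l` prints KiB.

```python
import resource


def describe_limit(value):
    """Render an rlimit value; RLIMIT_MEMLOCK is measured in bytes."""
    if value == resource.RLIM_INFINITY:
        return "unlimited"
    return f"{value} bytes ({value // 1024} KiB)"


def memlock_report():
    """Report this process's soft and hard locked-memory limits."""
    soft, hard = resource.getrlimit(resource.RLIMIT_MEMLOCK)
    return f"RLIMIT_MEMLOCK soft={describe_limit(soft)} hard={describe_limit(hard)}"


print(memlock_report())
```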

Without NCCL_SOCKET_IFNAME=enp, it just hangs.

Of interest:

With both NCCL_IB_DISABLE=1 and NCCL_SOCKET_IFNAME=enp, I see:

sphinx1:2608631:2608743 [1] NCCL INFO NCCL_IB_DISABLE set by environment to 1.
sphinx1:2608631:2608743 [1] NCCL INFO NET/Socket : Using [0]enp177s0f0:172.24.67.75<0>
sphinx1:2608631:2608743 [1] NCCL INFO Using network Socket
sphinx1:2608631:2608741 [0] NCCL INFO Using network Socket

With just NCCL_SOCKET_IFNAME=enp, I see:

sphinx1:2609339:2609443 [0] NCCL INFO NET/IB : Using [0]mlx5_1:1/IB [1]mlx5_2:1/IB [2]mlx5_3:1/IB [3]mlx5_4:1/IB [RO]; OOB enp177s0f0:172.24.67.75<0>
sphinx1:2609339:2609443 [0] NCCL INFO Using network IB
sphinx1:2609339:2609445 [1] NCCL INFO Using network IB

With just NCCL_IB_DISABLE=1, I see:

sphinx1:2608874:2608985 [1] NCCL INFO NCCL_IB_DISABLE set by environment to 1.
sphinx1:2608874:2608985 [1] NCCL INFO NET/Socket : Using [0]enp177s0f0:172.24.67.75<0> [1]br-0a71013f336a:192.168.192.1<0> [2]br-8138d0f1a247:192.168.128.1<0> [3]br-b66a74608043:192.168.144.1<0>
sphinx1:2608874:2608985 [1] NCCL INFO Using network Socket
sphinx1:2608874:2608983 [0] NCCL INFO Using network Socket
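These logs show what `NCCL_SOCKET_IFNAME` changes:
- With only `NCCL_IB_DISABLE=1`, the socket transport lists `enp177s0f0` plus three `br-*` interfaces on 192.168.x.x addresses. These look like local Docker-style bridges.
- A plausible (unconfirmed) reading of the hang is that peers on other hosts cannot reach those bridge addresses.
- With only `NCCL_SOCKET_IFNAME=enp`, data still goes over IB, and `enp177s0f0` appears as the out-of-band (OOB) interface.

The sketch below is a simplified model of the filter syntax described in NCCL's environment-variable docs: comma-separated prefixes, `^` to exclude, `=` for exact names. It is not NCCL's own code, and `select_interfaces` is a hypothetical name.

```python
def select_interfaces(names, spec):
    """Simplified model of NCCL_SOCKET_IFNAME filtering (not NCCL's own code).

    spec is a comma-separated list of name prefixes. A leading '^' turns the
    list into an exclusion list; a following '=' switches prefix matching to
    exact-name matching.
    """
    exclude = spec.startswith("^")
    if exclude:
        spec = spec[1:]
    exact = spec.startswith("=")
    if exact:
        spec = spec[1:]
    patterns = [p for p in spec.split(",") if p]

    def matches(name):
        return any(name == p if exact else name.startswith(p) for p in patterns)

    # Keep names that match an inclusion list, or fail to match an exclusion list.
    return [n for n in names if matches(n) != exclude]


# Interfaces NCCL listed when only NCCL_IB_DISABLE=1 was set.
seen = ["enp177s0f0", "br-0a71013f336a", "br-8138d0f1a247", "br-b66a74608043"]
print(select_interfaces(seen, "enp"))   # ['enp177s0f0']
print(select_interfaces(seen, "^br-"))  # ['enp177s0f0']
```

Under this model, an exclusion such as `^br-` would select the same interface as `enp` on this node. Both forms are shown only to illustrate the syntax; neither was tested in the thread.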