rmm: [BUG] Unexpected memory usage on GPU0

Describe the bug I tried to use RMM with PyTorch. I launch my task with torchrun and set the rmm.mr for each device at the very beginning of the script:

import os

import rmm
import torch
from rmm.allocators.torch import rmm_torch_allocator

# Route every PyTorch CUDA allocation through RMM.
torch.cuda.change_current_allocator(rmm_torch_allocator)

# One pool per process, intended for the GPU given by LOCAL_RANK.
pool = rmm.mr.PoolMemoryResource(
    rmm.mr.CudaMemoryResource(),
    initial_pool_size=2**30,
)
device = int(os.environ['LOCAL_RANK'])

rmm.mr.set_per_device_resource(device, pool)

But each process occupies a chunk of memory on GPU0, as shown in the attached nvidia-smi screenshot.

Steps/Code to reproduce bug
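The snippet above is the whole setup; it is launched with a plain torchrun invocation along these lines (the script name and process count are illustrative for an 8-GPU node), and the extra usage on GPU0 shows up in nvidia-smi as soon as the pools are created:

torchrun --nproc_per_node=8 train.py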

Expected behavior I expected each process launched by torchrun to use only the memory on the GPU assigned by LOCAL_RANK.
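To make the per-device accounting concrete, here is a small pynvml sketch (pynvml is already in the environment; this check is illustrative and was not part of my original run) that lists which PIDs hold memory on each GPU. With the setup above, every rank also appears under GPU0:

import pynvml

# List every process holding memory on every GPU, so the unexpected
# per-rank footprint on GPU0 is visible programmatically.
pynvml.nvmlInit()
for i in range(pynvml.nvmlDeviceGetCount()):
    handle = pynvml.nvmlDeviceGetHandleByIndex(i)
    for p in pynvml.nvmlDeviceGetComputeRunningProcesses(handle):
        # usedGpuMemory is reported in bytes (it may be None on some setups)
        print(f"GPU {i}: pid={p.pid} used={p.usedGpuMemory}")
pynvml.nvmlShutdown()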

Environment details I’m using RMM v23.10.00. Here is the output of print_env.sh:

***GPU Information***
Thu Dec  7 17:18:59 2023
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.129.03             Driver Version: 535.129.03   CUDA Version: 12.3     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA H800                    On  | 00000000:08:00.0 Off |                    0 |
| N/A   31C    P0              70W / 700W |      7MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   1  NVIDIA H800                    On  | 00000000:7E:00.0 Off |                    0 |
| N/A   26C    P0              69W / 700W |      7MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   2  NVIDIA H800                    On  | 00000000:A2:00.0 Off |                    0 |
| N/A   33C    P0              74W / 700W |      7MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   3  NVIDIA H800                    On  | 00000000:C6:00.0 Off |                    0 |
| N/A   29C    P0              69W / 700W |      7MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   4  NVIDIA H800                    On  | 00000001:09:00.0 Off |                    0 |
| N/A   26C    P0              71W / 700W |      7MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   5  NVIDIA H800                    On  | 00000001:7F:00.0 Off |                    0 |
| N/A   30C    P0              71W / 700W |      7MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   6  NVIDIA H800                    On  | 00000001:A3:00.0 Off |                    0 |
| N/A   30C    P0              71W / 700W |      7MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   7  NVIDIA H800                    On  | 00000001:C7:00.0 Off |                    0 |
| N/A   35C    P0              73W / 700W |      7MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+

+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|  No running processes found                                                           |
+---------------------------------------------------------------------------------------+

***CPU***
Architecture:          x86_64
CPU op-mode(s):        32-bit, 64-bit
Byte Order:            Little Endian
CPU(s):                192
On-line CPU(s) list:   0-183
Off-line CPU(s) list:  184-191
Thread(s) per core:    1
Core(s) per socket:    48
Socket(s):             2
NUMA node(s):          2
Vendor ID:             GenuineIntel
CPU family:            6
Model:                 143
Model name:            Intel(R) Xeon(R) Platinum 8469C
Stepping:              8
CPU MHz:               3100.000
CPU max MHz:           3800.0000
CPU min MHz:           800.0000
BogoMIPS:              5200.00
Virtualization:        VT-x
L1d cache:             48K
L1i cache:             32K
L2 cache:              2048K
L3 cache:              99840K
NUMA node0 CPU(s):     0-47,96-143
NUMA node1 CPU(s):     48-95,144-191
Flags:                 fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb cat_l3 cat_l2 cdp_l3 invpcid_single intel_ppin cdp_l2 ssbd mba ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb intel_pt avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local split_lock_detect avx_vnni avx512_bf16 wbnoinvd dtherm ida arat pln pts hwp hwp_notify hwp_act_window hwp_epp hwp_pkg_req avx512vbmi umip pku ospke waitpkg avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg tme avx512_vpopcntdq rdpid bus_lock_detect cldemote movdiri movdir64b enqcmd fsrm uintr md_clear serialize tsxldtrk pconfig arch_lbr amx_bf16 avx512_fp16 amx_tile amx_int8 flush_l1d arch_capabilities

***CMake***
/usr/local/cmake-3.22.0-linux-x86_64/bin/cmake
cmake version 3.22.0

CMake suite maintained and supported by Kitware (kitware.com/cmake).

***g++***
/usr/bin/g++
g++ (GCC) 8.3.1 20190311 (Red Hat 8.3.1-3)
Copyright (C) 2018 Free Software Foundation, Inc.
This is free software; see the source for copying conditions.  There is NO
warranty; not even for MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.


***nvcc***
/usr/local/cuda/bin/nvcc
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2023 NVIDIA Corporation
Built on Fri_Sep__8_19:17:24_PDT_2023
Cuda compilation tools, release 12.3, V12.3.52
Build cuda_12.3.r12.3/compiler.33281558_0

***Python***
/opt/conda/envs/python3.8/bin/python
Python 3.8.13

***Environment Variables***
PATH                            : /opt/conda/envs/python3.8/bin:/opt/conda/condabin:PATH=/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/opt/conda/bin:/usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/local/cmake-3.22.0-linux-x86_64/bin:/usr/local/ninja:/home/admin/.local/bin:PATH=/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/opt/conda/bin:/usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/local/cmake-3.22.0-linux-x86_64/bin:/usr/local/ninja:/home/admin/.local/bin:/usr/local/sbin:PATH=/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/opt/conda/bin:/usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/local/cmake-3.22.0-linux-x86_64/bin:/usr/local/ninja:/home/admin/.local/bin:/usr/X11R6/bin:/opt/satools
LD_LIBRARY_PATH                 : /usr/local/lib64:/lib64:/usr/local/gcc75/lib:/usr/local/gcc75/lib64::/usr/local/nvidia/lib:/usr/local/nvidia/lib64:/usr/local/cuda/lib64:/usr/local/cuda/targets/x86_64-linux/lib:/usr/lib64:/pu:/opt/taobao/java/jre/lib/amd64/server:/apsara/alicpp/built/gcc-4.9.2/glog-0.3.4/lib:/apsara/alicpp/built/gcc-4.9.2/gflags-2.1.2/lib:/apsara/alicpp/built/gcc-4.9.2/protobuf-2.4.1.ali/lib:/apsara/alicpp/built/gcc-4.9.2/odps-cryptography-1.0.0/lib:/apsara/alicpp/built/gcc-4.9.2/boost-1.58.0.fix.thread/lib:/apsara/alicpp/built/gcc-4.9.2/openssl-1.0.2a/lib:/apsara/alicpp/built/gcc-4.9.2/mysql-connector-c-6.1.6/lib:/apsara/alicpp/built/gcc-4.9.2/arrow-0.16.0/lib64:/apsara/alicpp/built/gcc-4.9.2/bzip2-1.0.6/lib64:/apsara/alicpp/built/gcc-4.9.2/zstd-1.4.4/lib:/apsara/alicpp/built/gcc-4.9.2/libevent-2.0.22.stable/lib64:/worker:/worker/lib:/opt/conda/envs/python3.8.13/lib:/usr/local/hadoop/hadoop/lib/native
NUMBAPRO_NVVM                   :
NUMBAPRO_LIBDEVICE              :
CONDA_PREFIX                    : /opt/conda/envs/python3.8
PYTHON_PATH                     :

***conda packages***
/opt/conda/condabin/conda
# packages in environment at /opt/conda/envs/python3.8:
#
# Name                    Version                   Build  Channel
_libgcc_mutex             0.1                        main
absl-py                   2.0.0                    pypi_0    pypi
acm-sdk-python            0.4.11                   pypi_0    pypi
ai-scheduler              0.2-89a0e10133697087a464f6bc434bfb9f8b639eb5          pypi_0    pypi
aliyun-python-sdk-core    2.14.0                   pypi_0    pypi
aliyun-python-sdk-kms     2.16.2                   pypi_0    pypi
annotated-types           0.6.0                    pypi_0    pypi
apex                      0.1                      pypi_0    pypi
astunparse                1.6.3                    pypi_0    pypi
async-timeout             4.0.3                    pypi_0    pypi
av                        8.1.0                    pypi_0    pypi
bitsandbytes              0.41.0                   pypi_0    pypi
ca-certificates           2023.08.22           h06a4308_0
cachetools                5.3.2                    pypi_0    pypi
certifi                   2023.11.17               pypi_0    pypi
cffi                      1.16.0                   pypi_0    pypi
chardet                   3.0.4                    pypi_0    pypi
charset-normalizer        3.3.2                    pypi_0    pypi
common-io                 0.8.3                    pypi_0    pypi
crcmod                    1.7                      pypi_0    pypi
cryptography              41.0.7                   pypi_0    pypi
cuda-python               11.8.2                   pypi_0    pypi
cycler                    0.12.1                   pypi_0    pypi
cython                    3.0.6                    pypi_0    pypi
deepspeed                 0.8.2                    pypi_0    pypi
docopt                    0.6.2                    pypi_0    pypi
easydict                  1.11                     pypi_0    pypi
einops                    0.7.0                    pypi_0    pypi
filelock                  3.13.1                   pypi_0    pypi
flash-attn                2.2.1                    pypi_0    pypi
fsspec                    2023.12.0                pypi_0    pypi
future                    0.18.2                   pypi_0    pypi
google-auth               2.24.0                   pypi_0    pypi
google-auth-oauthlib      1.0.0                    pypi_0    pypi
gputil                    1.4.0                    pypi_0    pypi
grpcio                    1.59.3                   pypi_0    pypi
hdfs                      2.7.3                    pypi_0    pypi
hjson                     3.1.0                    pypi_0    pypi
idna                      2.8                      pypi_0    pypi
importlib-metadata        6.8.0                    pypi_0    pypi
intel-openmp              2024.0.0                 pypi_0    pypi
jieba                     0.42.1                   pypi_0    pypi
jinja2                    3.1.2                    pypi_0    pypi
jmespath                  0.10.0                   pypi_0    pypi
joblib                    1.3.2                    pypi_0    pypi
kazoo                     2.9.0                    pypi_0    pypi
kiwisolver                1.4.5                    pypi_0    pypi
kmontitor-client          0.0.0                    pypi_0    pypi
lake-py-lib               0.1.7.ziying             pypi_0    pypi
ld_impl_linux-64          2.38                 h1181459_1
libffi                    3.3                  he6710b0_2
libgcc-ng                 9.1.0                hdf63c60_0
libstdcxx-ng              9.1.0                hdf63c60_0
lightning-utilities       0.10.0                   pypi_0    pypi
llvmlite                  0.41.1                   pypi_0    pypi
lmdb                      0.94                     pypi_0    pypi
lru-dict                  1.3.0                    pypi_0    pypi
magma-cuda121             2.6.1                         1    pytorch
markdown                  3.5.1                    pypi_0    pypi
markupsafe                2.1.3                    pypi_0    pypi
matplotlib                3.3.4                    pypi_0    pypi
mdl                       0.2                      pypi_0    pypi
mkl                       2024.0.0                 pypi_0    pypi
mkl-include               2024.0.0                 pypi_0    pypi
mpmath                    1.3.0                    pypi_0    pypi
ncurses                   6.3                  h7f8727e_2
nebula-mos-python-sdk     0.3.16                   pypi_0    pypi
nebula-py-pangu-early-test 0.0.41                   pypi_0    pypi
networkx                  3.1                      pypi_0    pypi
ninja                     1.11.1.1                 pypi_0    pypi
numba                     0.58.1                   pypi_0    pypi
numpy                     1.24.4                   pypi_0    pypi
nvidia-ml-py3             7.352.0                  pypi_0    pypi
oauthlib                  3.2.2                    pypi_0    pypi
opencv-python             4.5.4.60                 pypi_0    pypi
openssl                   1.1.1w               h7f8727e_0
oss2                      2.18.3                   pypi_0    pypi
packaging                 23.2                     pypi_0    pypi
pandas                    1.1.5                    pypi_0    pypi
pillow                    8.4.0                    pypi_0    pypi
pip                       23.3.1           py38h06a4308_0
protobuf                  3.20.1                   pypi_0    pypi
psutil                    5.9.5                    pypi_0    pypi
py-cpuinfo                9.0.0                    pypi_0    pypi
py-spy                    0.3.14                   pypi_0    pypi
pyarrow                   14.0.1                   pypi_0    pypi
pyasn1                    0.5.1                    pypi_0    pypi
pyasn1-modules            0.3.0                    pypi_0    pypi
pybind11                  2.11.1                   pypi_0    pypi
pycparser                 2.21                     pypi_0    pypi
pycryptodome              3.19.0                   pypi_0    pypi
pydantic                  1.10.9                   pypi_0    pypi
pydantic-core             2.14.5                   pypi_0    pypi
pydicom                   1.2.2                    pypi_0    pypi
pykmonitor                1.0                      pypi_0    pypi
pynvml                    11.5.0                   pypi_0    pypi
pyodps-int                0.11.5                   pypi_0    pypi
pyparsing                 3.1.1                    pypi_0    pypi
pytest-runner             6.0.0                    pypi_0    pypi
python                    3.8.13               h12debd9_0
python-dateutil           2.8.2                    pypi_0    pypi
pytz                      2023.3.post1             pypi_0    pypi
pyyaml                    6.0.1                    pypi_0    pypi
readline                  8.1.2                h7f8727e_1
redis                     5.0.1                    pypi_0    pypi
regex                     2023.10.3                pypi_0    pypi
requests                  2.22.0                   pypi_0    pypi
requests-oauthlib         1.3.1                    pypi_0    pypi
retrying                  1.3.4                    pypi_0    pypi
rmm                       23.10.0                  pypi_0    pypi
rsa                       4.9                      pypi_0    pypi
scikit-learn              0.24.2                   pypi_0    pypi
scipy                     1.7.3                    pypi_0    pypi
sentencepiece             0.1.96                   pypi_0    pypi
setuptools                68.0.0           py38h06a4308_0
simplejson                3.17.6                   pypi_0    pypi
six                       1.16.0                   pypi_0    pypi
sklearn                   0.0.post12               pypi_0    pypi
sqlite                    3.38.5               hc218d9a_0
sympy                     1.12                     pypi_0    pypi
tbb                       2021.11.0                pypi_0    pypi
tensorboard               2.14.0                   pypi_0    pypi
tensorboard-data-server   0.7.2                    pypi_0    pypi
thop-statistics           0.1.1-2303141613          pypi_0    pypi
threadpoolctl             3.2.0                    pypi_0    pypi
thrift                    0.16.0                   pypi_0    pypi
tk                        8.6.12               h1ccaba5_0
torch                     2.1.0+cu121              pypi_0    pypi
torchaudio                2.1.0+cu121              pypi_0    pypi
torchmetrics              1.2.1                    pypi_0    pypi
torchvision               0.16.0+cu121             pypi_0    pypi
tornado                   6.1                      pypi_0    pypi
tqdm                      4.66.1                   pypi_0    pypi
transformer-engine        1.0.0+66d91d5            pypi_0    pypi
transitions               0.9.0                    pypi_0    pypi
triton                    2.1.0                    pypi_0    pypi
typing-extensions         4.8.0                    pypi_0    pypi
urllib3                   1.25.11                  pypi_0    pypi
werkzeug                  3.0.1                    pypi_0    pypi
wheel                     0.41.2           py38h06a4308_0
xformers                  0.0.22.post4             pypi_0    pypi
xz                        5.2.5                h7f8727e_1
zipp                      3.17.0                   pypi_0    pypi
zlib                      1.2.12               h7f8727e_2


About this issue

  • Original URL
  • State: closed
  • Created 7 months ago
  • Comments: 17 (11 by maintainers)

Most upvoted comments

Great, thanks! In the end we are going with the code in #1407, which I very much hope also works identically; if you could confirm that, it would be wonderful.

It works fine.

Can you check whether the code in #1408 works for you, @li-yi-dong?

It works pretty smoothly with my task. And RMM really outperforms the PyTorch caching allocator in terms of fragmentation.

I was able to reproduce this running with four GPUs; I have yet to figure out what is going on. Debugging under gdb is difficult here because torchrun runs things in subprocesses, but if we run gdb with set detach-on-fork off and set follow-fork-mode child, we can eventually get to the relevant process and I can get a backtrace.
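For reference, this is roughly the kind of gdb session I mean (torchrun is a thin wrapper around python -m torch.distributed.run; the script name and the inferior number are illustrative):

$ gdb --args python -m torch.distributed.run --nproc_per_node=4 repro.py
(gdb) set detach-on-fork off
(gdb) set follow-fork-mode child
(gdb) run
...
(gdb) info inferiors      # list the forked worker processes gdb is holding
(gdb) inferior 4          # switch to the worker of interest
(gdb) backtrace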

Next step is to build RMM in debug mode so I have some symbols to inspect.

This is what I have right now to debug; note that I only need to allocate things on a single device:

import os
import rmm
from rmm.allocators.torch import rmm_torch_allocator
import torch

device = int(os.environ['LOCAL_RANK'])
if device == 3:
    torch.cuda.change_current_allocator(rmm_torch_allocator)
    rmm._cuda.gpu.setDevice(device)
    pool = rmm.mr.PoolMemoryResource(
        rmm.mr.CudaMemoryResource(),
        initial_pool_size=2**30,
    )
    rmm.mr.set_per_device_resource(device, pool)
    tensor = torch.zeros(2, 3, device=f"cuda:{device}")
    print(torch.cuda.current_device(), device, os.getpid(), flush=True)
    print(tensor, flush=True)
    del tensor

So my suspicion is that torch is shuffling CUDA devices out from under us in a bad way.
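One way to test that suspicion (a sketch using cuda-python, which is already in the environment; not something RMM itself needs) is to ask the CUDA runtime directly which device it considers current on each rank right before the pool is constructed:

import os

from cuda import cudart

rank = int(os.environ['LOCAL_RANK'])
# If the process is still (or again) on device 0 at this point, the pool's
# upstream CudaMemoryResource will grab its initial allocation there.
err, dev = cudart.cudaGetDevice()
assert err == cudart.cudaError_t.cudaSuccess
print(f"rank={rank} pid={os.getpid()} runtime current device={dev}", flush=True)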
