jax: jax_threefry_partitionable + rematerialization don't seem to work together in distributed training
Description
I have a transformer model where each transformer block is rematerialized. The model is distributed over multiple devices using jit. Each transformer block has dropout enabled.
To prevent the rng implementation from inserting synchronization operations, I'm also enabling jax_threefry_partitionable, as suggested in the docs.
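For reference, a minimal sketch of how the flag can be enabled (not taken verbatim from my setup; the environment-variable form should also work):

```python
import jax

# Use the partitionable threefry implementation so that sharded random
# number generation does not force cross-device synchronization.
jax.config.update("jax_threefry_partitionable", True)
# Alternatively, set JAX_THREEFRY_PARTITIONABLE=1 in the environment.
```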
The problem is that jax_threefry_partitionable doesn't seem to play nicely with rematerialization. As soon as I enable dropout, I get a GPU OOM because JAX decides to preserve huge arrays containing one rng key per activation tensor element for each transformer block, despite those blocks being rematerialized. It should be possible for JAX to reconstruct this key array from a single key during rematerialization, but it doesn't seem to do that.
I'm happy to provide a reproduction if you can confirm that this is unexpected behavior. If not, can you please suggest a workaround? Currently it doesn't seem possible to efficiently train large models with dropout.
A relevant discussion, with an example of the OOM error message, is here: https://github.com/google/flax/discussions/3090
What jax/jaxlib version are you using?
0.4.14
Which accelerator(s) are you using?
GPU
Additional system info
python3.10
NVIDIA GPU info
Fri Oct 6 14:22:06 2023
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.12 Driver Version: 535.104.12 CUDA Version: 12.2 |
|-----------------------------------------+----------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+======================+======================|
| 0 NVIDIA H100 80GB HBM3 On | 00000000:8A:00.0 Off | 0 |
| N/A 35C P0 72W / 700W | 4MiB / 81559MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+----------------------+----------------------+
| 1 NVIDIA H100 80GB HBM3 On | 00000000:8B:00.0 Off | 0 |
| N/A 31C P0 71W / 700W | 4MiB / 81559MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+----------------------+----------------------+
| 2 NVIDIA H100 80GB HBM3 On | 00000000:8C:00.0 Off | 0 |
| N/A 31C P0 72W / 700W | 4MiB / 81559MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+----------------------+----------------------+
| 3 NVIDIA H100 80GB HBM3 On | 00000000:8D:00.0 Off | 0 |
| N/A 35C P0 74W / 700W | 4MiB / 81559MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+----------------------+----------------------+
| 4 NVIDIA H100 80GB HBM3 On | 00000000:9C:00.0 Off | 0 |
| N/A 36C P0 76W / 700W | 4MiB / 81559MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+----------------------+----------------------+
| 5 NVIDIA H100 80GB HBM3 On | 00000000:9D:00.0 Off | 0 |
| N/A 31C P0 72W / 700W | 4MiB / 81559MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+----------------------+----------------------+
| 6 NVIDIA H100 80GB HBM3 On | 00000000:9E:00.0 Off | 0 |
| N/A 31C P0 71W / 700W | 4MiB / 81559MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+----------------------+----------------------+
| 7 NVIDIA H100 80GB HBM3 On | 00000000:9F:00.0 Off | 0 |
| N/A 35C P0 73W / 700W | 4MiB / 81559MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+----------------------+----------------------+
+---------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=======================================================================================|
| No running processes found |
+---------------------------------------------------------------------------------------+
About this issue
- State: open
- Created 9 months ago
- Reactions: 9
- Comments: 16 (5 by maintainers)
@mattjj @froystig Happy new year, gentlemen! Do you think 2024 is the year when this bug finally gets fixed? 😉
Looks like it. Interestingly, it also seems to fix another rng-related issue: https://github.com/google/jax/issues/19893
Btw, can you elaborate a bit on how the rng implementation works when keys are sharded? E.g., does it require any additional communication?
@hr0nix sorry that this slipped through the cracks. Thanks for the pings, everyone.
Can you check that this repros with jaxlib 0.4.20? IIRC there was one GPU-specific remat fix that happened recently, though I don’t have a link to it at the moment. EDIT: https://github.com/openxla/xla/pull/6527
I've made a repro for this bug. It turns out it has nothing to do with jax_threefry_partitionable and is perfectly reproducible without it. The repro was made for an A100 80GB, so tensor shapes might need to be adjusted for a GPU with a different amount of memory.
- ./repro.py: fails, because without rematerialization it needs ~122.75 GB of GPU RAM.
- ./repro.py --remat: works perfectly fine with remat, because it now needs just ~63 GB of GPU RAM.
- ./repro.py --remat --dropout-rate 0.1: OOMs again, requiring ~118 GB of GPU RAM. Looking at the peak buffers makes it clear that the dropout mask is not being rematerialized: tensors corresponding to the full dropout masks for the different layers are occupying memory.

Repro code:
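(The original repro.py is not reproduced above. As a rough illustration only, here is a minimal sketch of the pattern being described: a jit-compiled stack of remat-wrapped blocks with dropout. All names, shapes, and the dropout implementation are assumptions for illustration, not the author's actual script.)

```python
# Hypothetical sketch (not the original repro.py): remat-wrapped blocks with dropout.
# The expectation is that the dropout mask / per-element random bits are recomputed
# in the backward pass rather than being kept alive as residuals.
import jax
import jax.numpy as jnp

DROPOUT_RATE = 0.1  # illustrative value
NUM_BLOCKS = 4      # illustrative value


def block(x, w, key):
    # Stand-in for a transformer block: matmul + nonlinearity + dropout.
    y = jax.nn.gelu(x @ w)
    keep = jax.random.bernoulli(key, 1.0 - DROPOUT_RATE, y.shape)
    return jnp.where(keep, y / (1.0 - DROPOUT_RATE), 0.0)


@jax.jit
def loss(x, weights, key):
    remat_block = jax.checkpoint(block)  # i.e. jax.remat
    for i, w in enumerate(weights):
        # Derive a per-block key from a single scalar key; ideally only the
        # scalar key would need to be saved for the backward pass.
        x = remat_block(x, w, jax.random.fold_in(key, i))
    return x.mean()


if __name__ == "__main__":
    key = jax.random.PRNGKey(0)
    x = jnp.ones((4096, 4096), jnp.float32)
    weights = [jnp.ones((4096, 4096), jnp.float32) for _ in range(NUM_BLOCKS)]
    grads = jax.grad(loss, argnums=1)(x, weights, key)
    print(jax.tree_util.tree_map(jnp.shape, grads))
```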