jax: jax_threefry_partitionable and rematerialization don't seem to work together in distributed training

Description

I have a transformer model where each transformer block is rematerialized. The model is distributed over multiple devices using jit. Each transformer block has dropout enabled.

To prevent the rng implementation from inserting synchronization operations, I'm also enabling jax_threefry_partitionable, as suggested in the docs.
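For reference, the flag can be enabled through the config API (or, equivalently, the JAX_THREEFRY_PARTITIONABLE=1 environment variable):

import jax

jax.config.update("jax_threefry_partitionable", True)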

The problem is that jax_threefry_partitionable doesn't seem to play nicely with rematerialization. As soon as I enable dropout, I get a GPU OOM because JAX preserves huge arrays containing an rng key per activation tensor element for each transformer block, even though the blocks are rematerialized. It should be possible for JAX to reconstruct these key arrays from a single key during rematerialization, but it doesn't seem to do that.
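As a standalone illustration of what I'd expect (a minimal sketch, not code from my model), jax.ad_checkpoint.print_saved_residuals can be used to see which residuals JAX plans to save for the backward pass of a rematerialized dropout:

import jax
import jax.numpy as jnp
from jax import ad_checkpoint

jax.config.update("jax_threefry_partitionable", True)

@jax.checkpoint
def dropout_block(x, key):
    # The large bernoulli mask is exactly the kind of intermediate that
    # should be recomputed from `key` in the backward pass, not stored.
    keep_prob = 0.9
    mask = jax.random.bernoulli(key, p=keep_prob, shape=x.shape)
    return jnp.where(mask, x / keep_prob, 0.0)

def loss(x, key):
    return dropout_block(x, key).sum()

x = jnp.ones((1024, 1024))
key = jax.random.PRNGKey(0)
ad_checkpoint.print_saved_residuals(loss, x, key)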

I'm happy to provide a reproduction if you can confirm that this is unexpected behavior. If not, could you please suggest a workaround? Currently it doesn't seem possible to efficiently train large models with dropout.

A relevant discussion, with an example of the OOM error message, is here: https://github.com/google/flax/discussions/3090

What jax/jaxlib version are you using?

0.4.14

Which accelerator(s) are you using?

GPU

Additional system info

python3.10

NVIDIA GPU info

Fri Oct  6 14:22:06 2023
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.12             Driver Version: 535.104.12   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA H100 80GB HBM3          On  | 00000000:8A:00.0 Off |                    0 |
| N/A   35C    P0              72W / 700W |      4MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   1  NVIDIA H100 80GB HBM3          On  | 00000000:8B:00.0 Off |                    0 |
| N/A   31C    P0              71W / 700W |      4MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   2  NVIDIA H100 80GB HBM3          On  | 00000000:8C:00.0 Off |                    0 |
| N/A   31C    P0              72W / 700W |      4MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   3  NVIDIA H100 80GB HBM3          On  | 00000000:8D:00.0 Off |                    0 |
| N/A   35C    P0              74W / 700W |      4MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   4  NVIDIA H100 80GB HBM3          On  | 00000000:9C:00.0 Off |                    0 |
| N/A   36C    P0              76W / 700W |      4MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   5  NVIDIA H100 80GB HBM3          On  | 00000000:9D:00.0 Off |                    0 |
| N/A   31C    P0              72W / 700W |      4MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   6  NVIDIA H100 80GB HBM3          On  | 00000000:9E:00.0 Off |                    0 |
| N/A   31C    P0              71W / 700W |      4MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   7  NVIDIA H100 80GB HBM3          On  | 00000000:9F:00.0 Off |                    0 |
| N/A   35C    P0              73W / 700W |      4MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+

+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|  No running processes found                                                           |
+---------------------------------------------------------------------------------------+

About this issue

  • State: open
  • Created 9 months ago
  • Reactions: 9
  • Comments: 16 (5 by maintainers)

Most upvoted comments

@mattjj @froystig Happy new year, gentlemen! Do you think 2024 is the year when this bug finally gets fixed? 😉

I understood your most recent comment to mean that you have a workaround. Is that right?

Looks like it. Interestingly, it also seems to fix another rng-related issue: https://github.com/google/jax/issues/19893

Btw, can you elaborate a bit on how the rng implementation works when keys are sharded? E.g., does it require any additional communication?
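One way to check this empirically (a sketch of my own, assuming at least two visible devices; not a confirmed answer) is to constrain the sampled output to a sharded layout and look for collectives in the compiled HLO:

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

jax.config.update("jax_threefry_partitionable", True)

mesh = Mesh(np.array(jax.devices()), ("data",))
sharding = NamedSharding(mesh, P("data"))

@jax.jit
def sample(key):
    u = jax.random.uniform(key, shape=(len(jax.devices()) * 1024,))
    # Force the random bits to be sharded across devices.
    return jax.lax.with_sharding_constraint(u, sharding)

hlo = sample.lower(jax.random.PRNGKey(0)).compile().as_text()
# If sampling needed cross-device communication, collectives such as
# all-gather would appear in the compiled module.
print([line for line in hlo.splitlines() if "all-gather" in line])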

@hr0nix sorry that this slipped through the cracks. Thanks for the pings, everyone.

Can you check that this repros with jaxlib 0.4.20? IIRC there was one GPU-specific remat fix that happened recently, though I don’t have a link to it at the moment. EDIT: https://github.com/openxla/xla/pull/6527

I've made a repro for this bug. It turns out it has nothing to do with jax_threefry_partitionable; it's perfectly reproducible without it.

The repro was made for an A100 80GB, so tensor shapes might need to be adjusted for a GPU with a different amount of memory.

  • ./repro.py — fails, because without rematerialization it needs ~122.75 GB of GPU RAM
  • ./repro.py --remat — works fine, because with remat it needs just ~63 GB of GPU RAM
  • ./repro.py --remat --dropout-rate 0.1 — OOMs again, requiring ~118 GB of GPU RAM

Looking at the peak buffers makes it clear that the dropout mask is not being rematerialized: tensors corresponding to the full dropout masks of different layers occupy memory.

Peak buffers:
        Buffer 1:
                Size: 512.00MiB
                Operator: op_name="jit(train_step)/jit(main)/jvp(Model)/Block_5._apply_block/Block_5/Dropout_0/jit(_bernoulli)/jit(_uniform)/threefry2x32" source_file="/papyrax/./tools/repro.py" source_line=24
                XLA Label: custom-call
                Shape: u32[134217728]
                ==========================
        Buffer 2:
                Size: 512.00MiB
                Operator: op_name="jit(train_step)/jit(main)/jvp(Model)/Block_4._apply_block/Block_4/Dropout_0/jit(_bernoulli)/jit(_uniform)/threefry2x32" source_file="/papyrax/./tools/repro.py" source_line=24
                XLA Label: custom-call
                Shape: u32[134217728]
                ==========================

        Buffer 3:
                Size: 512.00MiB
                Operator: op_name="jit(train_step)/jit(main)/jvp(Model)/Block_4._apply_block/Block_4/Dropout_0/jit(_bernoulli)/jit(_uniform)/threefry2x32" source_file="/papyrax/./tools/repro.py" source_line=24
                XLA Label: custom-call
                Shape: u32[134217728]
                ==========================

        Buffer 4:
                Size: 512.00MiB
                Operator: op_name="jit(train_step)/jit(main)/jvp(Model)/Block_3._apply_block/Block_3/Dropout_0/jit(_bernoulli)/jit(_uniform)/threefry2x32" source_file="/papyrax/./tools/repro.py" source_line=24
                XLA Label: custom-call
                Shape: u32[134217728]
                ==========================

        Buffer 5:
                Size: 512.00MiB
                Operator: op_name="jit(train_step)/jit(main)/jvp(Model)/Block_3._apply_block/Block_3/Dropout_0/jit(_bernoulli)/jit(_uniform)/threefry2x32" source_file="/papyrax/./tools/repro.py" source_line=24
                XLA Label: custom-call
                Shape: u32[134217728]
                ==========================
...

Repro code:

import functools

import click
import flax
import flax.linen as nn
import flax.training.train_state
import jax
import jax.numpy as jnp
import optax


class Dropout(nn.Module):
    rate: float

    @nn.compact
    def __call__(self, inputs, rng):
        if self.rate == 0.0:
            return inputs

        if self.rate == 1.0:
            return jnp.zeros_like(inputs)

        keep_prob = 1.0 - self.rate
        mask = jax.random.bernoulli(rng, p=keep_prob, shape=inputs.shape)
        return jax.lax.select(mask, inputs / keep_prob, jnp.zeros_like(inputs))


class Block(nn.Module):
    dim: int
    dropout_rate: float

    @nn.compact
    def __call__(self, input, rng):
        scale = 32  # We want large memory consumption without remat
        emb = nn.Dense(features=self.dim * scale)(input)
        emb = nn.relu(emb)
        emb = Dropout(rate=self.dropout_rate)(emb, rng)
        emb = nn.Dense(features=self.dim)(emb)
        return emb


class Model(nn.Module):
    dim: int
    dropout_rate: float
    num_layers: int
    remat: bool

    @nn.compact
    def __call__(self, input, rng):
        def _apply_block(block, block_input, rng):
            return block(block_input, rng)

        if self.remat:
            _apply_block = nn.checkpoint(
                _apply_block,
                policy=jax.checkpoint_policies.nothing_saveable,
                prevent_cse=True,
            )

        emb = input
        for _ in range(self.num_layers):
            rng, block_rng = jax.random.split(rng)
            block = Block(dim=self.dim, dropout_rate=self.dropout_rate)
            emb = _apply_block(block, emb, block_rng)

        return emb


def loss_fn(params, train_state, batch, rng):
    outputs = train_state.apply_fn(params, batch, rng)
    return jnp.mean(outputs * outputs)


@functools.partial(jax.jit, donate_argnames=("train_state",))
def train_step(train_state, batch, rng):
    grad_fn = jax.grad(loss_fn)
    grad = grad_fn(train_state.params, train_state, batch, rng)
    train_state = train_state.apply_gradients(grads=grad)
    return train_state


def make_batch(batch_size, dim):
    return jnp.zeros(shape=(batch_size, dim), dtype=jnp.float32)


@click.command()
@click.option("--dim", type=int, default=1024)
@click.option("--batch-size", type=int, default=8192)
@click.option("--dropout-rate", type=float, default=0.0)
@click.option("--num-layers", type=int, default=64)
@click.option("--remat", is_flag=True)
def main(
    dim: int,
    batch_size: int,
    dropout_rate: float,
    num_layers: int,
    remat: bool,
):
    model = Model(
        dim=dim, dropout_rate=dropout_rate, num_layers=num_layers, remat=remat
    )
    batch = make_batch(batch_size=batch_size, dim=dim)
    rng = jax.random.PRNGKey(0)
    params = model.init({"params": rng}, batch, rng)
    optimizer = optax.adam(learning_rate=1e-3)
    train_state = flax.training.train_state.TrainState.create(
        apply_fn=model.apply, params=params, tx=optimizer
    )
    train_state = train_step(train_state, batch, rng)


if __name__ == "__main__":
    main()