equinox: Possible memory leak?

@patrick-kidger — as you know, I have been working on a port of Mistral, and I have now put that code in a public repo. A few TODOs remain that I will complete regardless, but I am dealing with a bigger issue. Here is the condensed implementation for reference:

import gc
import jax
import numpy as np
import equinox as eqx
import jax.numpy as jnp
import jax.tree_util as jtu
from collections import namedtuple


# Utility to convert dtypes
def to_dtype(model, dtype):
    def _to_dtype(leaf):
        if isinstance(leaf, jax.Array):  # not eqx.is_array, which also detects NumPy arrays
            leaf_with_dtype = leaf.astype(dtype)
            del leaf
            gc.collect()  # just in case?
            return leaf_with_dtype
        else:
            return leaf

    return jtu.tree_map(_to_dtype, model)
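
# Illustrative usage of to_dtype on a toy module (hypothetical sizes,
# just to show the cast round-trips):
_toy = eqx.nn.Linear(2, 3, key=jax.random.PRNGKey(0))
assert to_dtype(_toy, jnp.float16).weight.dtype == jnp.float16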


def count_jax_parameters(model):
    return sum(x.size for x in jtu.tree_leaves(eqx.filter(model, eqx.is_array)))


# 1. RoPE
def precompute_frequencies(dim, max_pos, theta=10000.0):
    inv_freq = 1.0 / (
        theta ** (jnp.arange(0, dim, 2, dtype=jnp.float32)[: (dim // 2)] / dim)
    )
    t = jnp.arange(0, max_pos, dtype=jnp.float32)
    freqs = jnp.outer(t, inv_freq)
    return jnp.cos(freqs), jnp.sin(freqs)
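
# Quick shape check (illustrative numbers): there is one rotation angle per
# (position, even/odd channel pair), so head_dim=128 gives 64 columns.
_cos_chk, _sin_chk = precompute_frequencies(128, 4096)
assert _cos_chk.shape == _sin_chk.shape == (4096, 64)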


def calculate_rope(x, cos_freq, sin_freq, offset=0):
    # x shape is [seqlen, num_heads, head_dim]

    # Get the sequence length
    seqlen = x.shape[0]

    # Get the corresponding positional embeddings
    sin = sin_freq[offset : offset + seqlen, :]
    cos = cos_freq[offset : offset + seqlen, :]

    # Positional embeddings are 2D while our input is 3D
    # if `num_heads` dimension is present in the inputs.
    # We need to add another dimension to our positional embeddings
    sin = sin[:, jnp.newaxis, :]
    cos = cos[:, jnp.newaxis, :]

    # Get the even-odd positions from the inputs
    x1 = x[..., 0::2]
    x2 = x[..., 1::2]

    # Matmul with the rotation matrix
    # [cos_nθ, -sin_nθ] [x1]
    # [sin_nθ,  cos_nθ] [x2]
    # => [x1 * cos_nθ - x2 * sin_nθ, x1 * sin_nθ + x2 * cos_nθ]
    pos_embed = jnp.stack([x1 * cos - x2 * sin, x1 * sin + x2 * cos], axis=-1)
    pos_embed = jax.lax.collapse(pos_embed, -2)
    return pos_embed.astype(x.dtype)
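
# Illustrative check: applying RoPE preserves the input shape; only the
# even/odd channel pairs within each head get mixed.
_x_chk = jnp.ones((8, 4, 128), dtype=jnp.float16)
_cos8, _sin8 = precompute_frequencies(128, 8)
assert calculate_rope(_x_chk, _cos8, _sin8).shape == _x_chk.shape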


# 2. Attention layer
class Attention(eqx.Module):
    n_heads: int
    n_kv_heads: int
    sliding_window: int
    scale: float
    kv_repeats: int
    head_dim: int
    wq: eqx.nn.Linear
    wk: eqx.nn.Linear
    wv: eqx.nn.Linear
    wo: eqx.nn.Linear
    
    def __init__(self, args: namedtuple, key: jax.Array):
        self.n_heads = args.n_heads
        self.n_kv_heads = args.n_kv_heads
        self.kv_repeats = self.n_heads // self.n_kv_heads
        self.sliding_window = args.sliding_window
        self.scale = args.head_dim ** -0.5
        self.head_dim = args.head_dim

        key1, key2, key3, key4 = jax.random.split(key, 4)
        self.wq = eqx.nn.Linear(args.dim, args.n_heads * args.head_dim, use_bias=False, key=key1)
        self.wk = eqx.nn.Linear(args.dim, args.n_kv_heads * args.head_dim, use_bias=False, key=key2)
        self.wv = eqx.nn.Linear(args.dim, args.n_kv_heads * args.head_dim, use_bias=False, key=key3)
        self.wo = eqx.nn.Linear(args.n_heads * args.head_dim, args.dim, use_bias=False, key=key4)

    def __call__(self, x, cos_freq, sin_freq, positions, mask):
        seqlen = x.shape[0]

        xq = jax.vmap(self.wq)(x)
        xk = jax.vmap(self.wk)(x)
        xv = jax.vmap(self.wv)(x)

        xq = jnp.reshape(xq, (seqlen, self.n_heads, self.head_dim))
        xk = jnp.reshape(xk, (seqlen, self.n_kv_heads, self.head_dim))
        xv = jnp.reshape(xv, (seqlen, self.n_kv_heads, self.head_dim))

        xq = calculate_rope(xq, cos_freq, sin_freq, 0)
        xk = calculate_rope(xk, cos_freq, sin_freq, 0)

        if positions.shape[0] > 1:
            # Prefill: repeat the K/V heads so grouped-query attention has
            # as many key/value heads as query heads.
            key = jnp.repeat(xk, self.kv_repeats, axis=1)
            value = jnp.repeat(xv, self.kv_repeats, axis=1)
        # TODO: else fill from cache (until then, the single-token decode
        # path leaves key/value unbound)

        query = jnp.transpose(xq, (1, 0, 2)) # [seqlen, num_heads, head_dim] -> [num_heads, seqlen, head_dim]
        key = jnp.transpose(key, (1, 0, 2)) # [seqlen, num_heads, head_dim] -> [num_heads, seqlen, head_dim]
        value = jnp.transpose(value, (1, 0, 2)) # [seqlen, num_heads, head_dim] -> [num_heads, seqlen, head_dim]

        # scores : [n_heads, seqlen | 1, seqlen]
        scores = jnp.matmul(query, jnp.transpose(key, (0, 2, 1))) * self.scale

        if mask is not None:
            # Mask will of shape [seqlen, seqlen] but our scores
            # have shape [num_heads, seqlen, seqlen], hence we need
            # to introduce another dimension in the mask
            mask = mask[jnp.newaxis, ...]
            scores = scores + mask

        scores = jax.nn.softmax(scores.astype(jnp.float32)).astype(query.dtype)
        output = jnp.matmul(scores, value)
        output = jnp.transpose(output, (1, 0, 2))  # [n_heads, seqlen, head_dim] -> [seqlen, n_heads, head_dim]
        output = jnp.reshape(output, (seqlen, -1))  # concatenate the heads
        output = jax.vmap(self.wo)(output)
        return output


# 3. FeedForward
class FeedForward(eqx.Module):
    w1: eqx.nn.Linear
    w2: eqx.nn.Linear
    w3: eqx.nn.Linear

    def __init__(self, args, key):
        super().__init__()
        key1, key2, key3 = jax.random.split(key, 3)

        self.w1 = eqx.nn.Linear(args.dim, args.hidden_dim, use_bias=False, key=key1)
        self.w2 = eqx.nn.Linear(args.hidden_dim, args.dim, use_bias=False, key=key2)
        self.w3 = eqx.nn.Linear(args.dim, args.hidden_dim, use_bias=False, key=key3)

    def __call__(self, x):
        return self.w2(jax.nn.silu(self.w1(x)) * self.w3(x))
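
# (FeedForward above is the SwiGLU-style gated MLP used by Llama/Mistral:
# silu(w1(x)) gates w3(x), and w2 projects back to the model dimension.)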


# 4. TransformerBlock
class TransformerBlock(eqx.Module):
    dim: int
    n_heads: int
    attention: Attention
    attention_norm: eqx.nn.RMSNorm
    feed_forward: FeedForward
    ffn_norm: eqx.nn.RMSNorm

    def __init__(self, args, key):
        key1, key2 = jax.random.split(key, 2)
        self.n_heads = args.n_heads
        self.dim = args.dim

        self.attention = Attention(args, key=key1)
        self.attention_norm = eqx.nn.RMSNorm(args.dim, eps=args.norm_eps, use_bias=False, use_weight=True)

        self.feed_forward = FeedForward(args, key=key2)
        self.ffn_norm = eqx.nn.RMSNorm(args.dim, eps=args.norm_eps, use_bias=False, use_weight=True)

    def __call__(self, x, cos_freq, sin_freq, positions, mask):
        normed_x = jax.vmap(self.attention_norm)(x.astype(jnp.float32)).astype(jnp.float16)
        r = self.attention(normed_x, cos_freq, sin_freq, positions, mask)
        h1 = x + r
        h2 = jax.vmap(self.ffn_norm)(h1.astype(jnp.float32)).astype(jnp.float16)
        h2 = jax.vmap(self.feed_forward)(h2)
        out = h1 + h2
        return out


# 5. Transformer
class Transformer(eqx.Module):
    tok_embeddings: eqx.nn.Embedding
    layers: TransformerBlock
    norm: eqx.nn.RMSNorm
    output: eqx.nn.Linear
    vocab_size: int
    n_layers: int
    
    def __init__(self, args, key):
        self.vocab_size = args.vocab_size
        self.n_layers = args.n_layers
        keys = jax.random.split(key, args.n_layers + 2)
        embed_key, linear_key, tf_layers_keys = keys[0], keys[-1], keys[1:-1]
        
        self.tok_embeddings = eqx.nn.Embedding(args.vocab_size, args.dim, key=embed_key)
        self.norm = eqx.nn.RMSNorm(args.dim, eps=args.norm_eps, use_bias=False, use_weight=True)
        self.output = eqx.nn.Linear(args.dim, args.vocab_size, use_bias=False, key=linear_key)

        make_tf_layers = lambda k: TransformerBlock(args, key=k)
        self.layers = eqx.filter_vmap(make_tf_layers)(tf_layers_keys)

    def __call__(self, x, positions):
        # x is of shape (seqlen, ). We need to use vmap
        # as the embedding layer expects single token (scalar)
        # as input.
        h = jax.vmap(self.tok_embeddings)(x) # output shape: [seqlen, embed_size]
        # precomputed_sin_freq/precomputed_cos_freq are the module-level
        # RoPE tables built below via precompute_frequencies.
        sin_freq = precomputed_sin_freq[positions]
        cos_freq = precomputed_cos_freq[positions]
        
        if x.shape[-1] > 1:
            seq_len = x.shape[-1]
            t = jnp.full((seq_len, seq_len), dtype=h.dtype, fill_value=1)
            mask = jnp.tril(t, k=0)
            # make the mask banded to account for sliding window
            # (args here is the module-level hparams namedtuple defined below)
            mask = jnp.triu(mask, k=-args.sliding_window)
            # the mask is added to the scores, so it must be 0 where attention
            # is allowed and -inf where it is not: log(1) = 0, log(0) = -inf
            mask = jnp.log(mask)
        else:
            mask = None

        # We need to call all the transformer blocks in a loop. Better to use lax.scan
        # as it would reduce compilation overhead and will be much faster.
        dynamic_tf_layers, static_tf_layers = eqx.partition(self.layers, eqx.is_array)
        
        def f(_x, _dynamic_tf_layers):
            tf_layer = eqx.combine(_dynamic_tf_layers, static_tf_layers)
            return tf_layer(_x, cos_freq, sin_freq, positions, mask), None

        h, _ = jax.lax.scan(f, h, dynamic_tf_layers)
        h = jax.vmap(self.norm)(h)
        h = jax.vmap(self.output)(h)
        # TODO: Calculate logits in this block
        return h
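
# Note on the stacked layers above: eqx.filter_vmap over the constructor
# builds a single TransformerBlock whose array leaves carry a leading
# [n_layers, ...] axis; lax.scan then peels off one layer per step, while
# eqx.partition/eqx.combine keep the static (non-array) fields intact.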


ModelArgs = namedtuple(
    "ModelArgs",
    [
        "dim",
        "n_layers",
        "hidden_dim",
        "n_heads",
        "head_dim",
        "n_kv_heads",
        "sliding_window",
        "norm_eps",
        "vocab_size",
        "max_batch_size"
    ]
)

# Same hparams as used in the original code
args = ModelArgs(
    dim=4096,
    n_layers=32,
    n_heads=32,
    n_kv_heads=8,
    head_dim=128,
    hidden_dim=14336,
    vocab_size=32000,
    max_batch_size=1,
    sliding_window=4096,
    norm_eps=1e-5
)

# Precompute the RoPE tables that Transformer.__call__ reads as module-level
# globals. Sizing them to the sliding window is an assumption here; use a
# larger max_pos if positions can exceed it.
precomputed_cos_freq, precomputed_sin_freq = precompute_frequencies(
    args.head_dim, args.sliding_window
)

# Initialize the model and cast it to float16
transformer = to_dtype(Transformer(args, key=jax.random.PRNGKey(1)), jnp.float16)
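
For reference, count_jax_parameters from above gives the size directly, and the raw weight footprint follows from it:

n_params = count_jax_parameters(transformer)  # ~7.24e9 with these hparams
print(f"{n_params / 1e9:.2f}B parameters, "
      f"~{2 * n_params / 1e9:.1f} GB in float16, "
      f"~{4 * n_params / 1e9:.1f} GB in float32")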

Given that this model has ~7.2B parameters (about 14.5 GB in float16), I should be able to load it on an A100 40GB GPU as well as on a TPU v3-8. The problem is that, no matter the device, it always errors out with an OOM.

I suspected the cause might be that I initialize the transformer in full precision and only then convert it to float16: at float32 the weights take ~29 GB, so while both copies exist during conversion the total is ~43 GB, which by itself would overflow a 40 GB card. So I tried it the other way around, converting each layer to float16 inside each block as soon as it is initialized. Even then it errors out with OOM. I suspect there is a memory leak somewhere, but I am not able to pinpoint it.
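
Condensed, that second attempt replaces each eqx.nn.Linear(...) call with a cast-at-creation helper along these lines (a sketch, reusing to_dtype from above; make_linear_f16 is just a name I am using here for illustration):

def make_linear_f16(in_dim, out_dim, key):
    # Build one Linear and cast it immediately, so at most one layer's
    # float32 weights are alive at any point during initialization.
    layer = eqx.nn.Linear(in_dim, out_dim, use_bias=False, key=key)
    return to_dtype(layer, jnp.float16)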

Any suggestions/pointers would be very helpful. Thanks in advance 🙏

Most upvoted comments

Sure, go ahead and open an issue to track this. If there is some cunning way to do this without modifying the existing API then that would definitely be best.