jax: memory leak of jax.random on GPU

@ziatdinovmax observed that memory keeps increasing until a crash for some NumPyro models (specifically, the issue https://github.com/pyro-ppl/numpyro/issues/447) when they are run on GPU. Here is an isolated repro, which uses only JAX ops:

from jax import random, lax

def model(key):
    D_X, D_H, D_Y = 3, 5, 1
    # sample first layer (we put unit normal priors on all weights)
    key, subkey = random.split(key)
    w1 = random.normal(subkey, shape=(D_X, D_H))
    key, subkey = random.split(key)
    w1 = random.uniform(subkey, shape=w1.shape, minval=-2, maxval=2)

    # sample second layer
    key, subkey = random.split(key)
    w2 = random.normal(subkey, shape=(D_H, D_H))
    key, subkey = random.split(key)
    w2 = random.uniform(subkey, shape=w2.shape, minval=-2, maxval=2)

    # sample final layer of weights and neural network output
    key, subkey = random.split(key)
    w3 = random.normal(subkey, shape=(D_H, D_Y))
    key, subkey = random.split(key)
    w3 = random.uniform(subkey, shape=w3.shape, minval=-2, maxval=2)

    # we put a prior on the observation noise
    key, subkey = random.split(key)
    prec_obs = random.normal(subkey)
    key, subkey = random.split(key)
    prec_obs = random.uniform(subkey, shape=prec_obs.shape, minval=-2, maxval=2)
    return w1, w2, w3, prec_obs

def cond_fn(state):
    return state[0] < 100

def body_fn(state):
    i, key, _ = state
    key, subkey = random.split(key)
    return i + 1, key, model(subkey)

init_state = (0, random.PRNGKey(0), model(random.PRNGKey(2019)))
i, key, params = lax.while_loop(cond_fn, body_fn, init_state)

Interestingly, if I use a rolled loop in threefry_2x32, the issue does not happen. Could that be a hint?
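
To make the "rolled loop" remark concrete, here is a minimal sketch of the difference between unrolling rounds in Python and rolling them with lax.fori_loop. The round_fn, unrolled, and rolled names are hypothetical stand-ins, not the actual threefry_2x32 internals:

from jax import lax

def round_fn(x):
    # stand-in for one threefry mixing round (not the real mixing function)
    return x * 5 + 1

def unrolled(x, n_rounds=20):
    # Python-level loop: each round becomes separate ops in the traced program,
    # so the computation handed to XLA grows with n_rounds
    for _ in range(n_rounds):
        x = round_fn(x)
    return x

def rolled(x, n_rounds=20):
    # lax.fori_loop remains a single loop construct in the traced program,
    # so its size does not grow with n_rounds
    return lax.fori_loop(0, n_rounds, lambda i, v: round_fn(v), x)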

cc @mattjj 😃

About this issue

  • State: closed
  • Created 5 years ago
  • Comments: 16 (16 by maintainers)

Most upvoted comments

PR #1756 makes the example in this issue compile quickly and without the memory blowup. Note that the fix requires a rebuild of jaxlib (or waiting for us to release new jaxlib wheels). Hope that helps!

@fehiepsi I think this is simply that XLA is taking a lot of memory to compile that computation. I raised an issue with the XLA team, although I suspect we are going to need to move the PRNG computation out of XLA and into a custom CUDA kernel as a short-term fix. PRNG compilation time/space is a constant pain point.
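
One rough way to gauge how large that computation is (a hedged aside; the exact output will vary by JAX version) is to print the jaxpr traced from the repro's model function; each random.split and sampling call expands into many threefry operations:

from jax import make_jaxpr, random

# `model` is the repro function defined above; the length of the printed jaxpr
# gives a rough sense of how much computation XLA has to compile.
print(make_jaxpr(model)(random.PRNGKey(0)))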

@neerajprad The experimental check_leaks support has bitrotted, so I wouldn’t trust the results. But if you can provide a small self-contained reproduction of the CPU memory problem, we can look into it.

No worries, please always report possible bugs eagerly so we can work through them, just like this! (We love it when you guys raise issues.)

Also, let us know if the caches are growing too large in your use cases. We can make the caching logic smarter.

@hawkinsp should we leave this issue open pending some XLA:GPU compilation follow-up?

There is a leak in the JaxprTracer graph though, exactly where Dougal thought there might be.

I’ll follow up with a PR.