numpyro: Memory leak in MCMC?
Hi there,
I have an issue: I’m trying to serve a NumPyro model using MLflow and MLServer. My model has varying input sizes and needs to re-estimate all parameters regularly. I’m using MCMC to estimate them, but memory usage increases linearly with each call.
I’ve seen https://github.com/pyro-ppl/numpyro/issues/1347 but it doesn’t fix it in my case (as the size of the inputs is changing). Is this a bug? Is there a way to force releasing memory manually? Thanks!
numpyro version: 0.13.2
jax version: 0.4.23
Code to reproduce (taken from https://github.com/pyro-ppl/numpyro/issues/1347).
from collections import Counter
import gc
import os

import numpy as np
import numpyro, jax
from numpyro import sample
import numpyro.distributions as dist
from jax import random
from numpyro.infer import MCMC, NUTS
import psutil

process = psutil.Process(os.getpid())
print("numpyro version: ", numpyro.__version__)
print("jax version: ", jax.__version__)


def model(y_obs):
    mu = sample('mu', dist.Normal(0., 1.))
    sigma = sample("sigma", dist.HalfCauchy(3.))
    y = sample("y", dist.Normal(mu, sigma), obs=y_obs)


for i in range(10):
    kernel = NUTS(model)
    mcmc = MCMC(kernel, num_samples=1, num_warmup=2, jit_model_args=True)
    # the observation length changes randomly on every iteration
    mcmc.run(random.PRNGKey(0), np.zeros((np.random.randint(0, 10),)))
    print("\nGC OBJECTS:")
    cnt = Counter()
    # force collection; it is expected that count of different types
    # should not increase per iteration
    gc.collect()
    for x in gc.get_objects():
        if isinstance(x, list):
            if len(x) > 1:
                cnt[type(x[0])] += 1
    print(cnt.most_common(10))
    # resident set size of this process, in MB
    memory_in_mb = process.memory_info().rss / 1024 / 1024
    print(i, memory_in_mb)
About this issue
- State: open
- Created 6 months ago
- Comments: 15 (10 by maintainers)
Huge thanks, @tare! I think we can open a JAX issue with your reproducible code.
@fehiepsi I ran a couple of experiments without NumPyro to see the effects of jax.jit(), jax.clear_caches(), and dynamic shapes. Did you mean something like that? These results suggest that the issue is not due to NumPyro. However, I don’t know the internals of JAX well enough to explain what is happening here.
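As a rough illustration of why dynamic shapes matter here, the sketch below counts how often a jitted function is traced. It is not one of the experiments above: the function `total` and the counter `trace_count` are made up for this example. It only relies on the documented behaviour that jax.jit traces and compiles a function once per distinct input shape and keeps the result cached, which may be one contributor to the growth reported in this thread.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # hypothetical counter, only for this illustration


@jax.jit
def total(x):
    global trace_count
    # This Python side effect runs only while JAX traces the function,
    # i.e. once for every input shape it has not seen before.
    trace_count += 1
    return jnp.sum(x)


# Five calls, but only three distinct shapes: (3,), (5,), (7,)
for n in [3, 3, 5, 5, 7]:
    total(jnp.zeros((n,)))

print("traces:", trace_count)
```

With static shapes the count would stay at one however many calls are made. With inputs whose size keeps changing, each new size adds another cached trace and compiled executable.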
- w/out jax.jit(), w/out jax.clear_caches(), and static shapes
- w/out jax.jit(), w/ jax.clear_caches(), and static shapes
- w/out jax.jit(), w/out jax.clear_caches(), and dynamic shapes
- w/out jax.jit(), w/ jax.clear_caches(), and dynamic shapes
- w/ jax.jit(), w/out jax.clear_caches(), and static shapes
- w/ jax.jit(), w/ jax.clear_caches(), and static shapes
- w/ jax.jit(), w/out jax.clear_caches(), and dynamic shapes
- w/ jax.jit(), w/ jax.clear_caches(), and dynamic shapes
This seems to solve the issue.
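A minimal sketch of how clearing the caches after every call could be wrapped up, assuming jax.clear_caches() is available in the installed JAX version. The helper `call_then_clear` and the stand-in function `mean_of` are hypothetical names for this example and not part of JAX or NumPyro.

```python
import gc

import jax
import jax.numpy as jnp


def call_then_clear(fn, *args, **kwargs):
    """Hypothetical helper: call fn, then drop JAX's compilation caches."""
    try:
        return fn(*args, **kwargs)
    finally:
        jax.clear_caches()  # discard cached traces and compiled executables
        gc.collect()        # collect Python objects that are no longer referenced


@jax.jit
def mean_of(x):
    # stand-in for an expensive jitted computation with varying input size
    return jnp.mean(x)


results = [float(call_then_clear(mean_of, jnp.ones((n,)))) for n in (2, 5, 9)]
print(results)
```

In the reproduction script, the same wrapper could presumably be placed around `mcmc.run(...)`. The trade-off is that every call recompiles from scratch, even for an input size that has been seen before, so this exchanges run time for bounded memory.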