numpyro: Memory leak in MCMC?

Hi there,

I have an issue: I’m trying to serve a NumPyro model using MLflow and MLServer. My model receives inputs of varying size and needs to re-estimate all of its parameters regularly. I’m using MCMC to estimate them, but memory usage increases linearly each time MCMC is run.

I’ve seen https://github.com/pyro-ppl/numpyro/issues/1347, but the fix suggested there doesn’t help in my case because the size of my inputs changes between calls. Is this a bug? Is there a way to force the memory to be released manually? Thanks!

numpyro version: 0.13.2; jax version: 0.4.23

Code to reproduce (taken from https://github.com/pyro-ppl/numpyro/issues/1347).

from collections import Counter

import numpy as np

import numpyro, jax
from numpyro import sample
import numpyro.distributions as dist
from jax import random, vmap
from numpyro.infer import MCMC, NUTS
import os
import psutil

# Handle on the running interpreter; its RSS is sampled after every run.
process = psutil.Process(pid=os.getpid())

# Record the library versions next to the measurements.
print("numpyro version: ", numpyro.__version__)
print("jax version: ", jax.__version__)


def model(y_obs):
    """Normal likelihood with unknown location and scale.

    Draws ``mu`` from a Normal(0, 1) prior and ``sigma`` from a HalfCauchy(3)
    prior, then conditions the observed site ``"y"`` ~ Normal(mu, sigma) on
    ``y_obs``.

    Args:
        y_obs: observed values. In the loop below this is a 1-D NumPy array of
            zeros whose length is drawn at random from [0, 10) on each run.
    """
    mu = sample("mu", dist.Normal(0.0, 1.0))
    sigma = sample("sigma", dist.HalfCauchy(3.0))
    # Observed site: the sampled value is never used, so it is not bound.
    sample("y", dist.Normal(mu, sigma), obs=y_obs)


import gc  # used below but missing from this snippet's imports (NameError otherwise)

for i in range(10):
    kernel = NUTS(model)
    mcmc = MCMC(kernel, num_samples=1, num_warmup=2, jit_model_args=True)
    # Random observation length in [0, 10) on every run, mirroring the
    # varying input sizes described in the report.
    mcmc.run(random.PRNGKey(0), np.zeros((np.random.randint(0, 10),)))

    print("\nGC OBJECTS:")
    cnt = Counter()
    # Force a collection first; the count of each type is expected to stay
    # flat from one iteration to the next if nothing is leaking.
    gc.collect()
    for x in gc.get_objects():
        # Bucket every list with more than one element by its first item's type.
        if isinstance(x, list) and len(x) > 1:
            cnt[type(x[0])] += 1
    print(cnt.most_common(10))

    memory_in_mb = process.memory_info().rss / 1024 / 1024
    print(i, memory_in_mb)

About this issue

  • Original URL
  • State: open
  • Created 6 months ago
  • Comments: 15 (10 by maintainers)

Most upvoted comments

Huge thanks, @tare! I think we can make a jax issue with your reproducible code.

@fehiepsi I did run a couple of experiments without NumPyro to see the effects of jax.jit(), jax.clear_caches(), and dynamic shapes. Did you mean something like that?

These results suggest that the issue is not caused by NumPyro. However, I don’t know the internals of JAX well enough to explain what is happening here.

w/out jax.jit(), w/out jax.clear_caches(), and static shapes

Code
import jax
import jax.numpy as jnp
import os
import psutil
import gc

# Handle on the running interpreter, sampled below for resident memory.
process = psutil.Process(pid=os.getpid())

print("jax version: ", jax.__version__)

def fn(x):
    """Return the sum of the element-wise sine of ``x`` (eager, no jit)."""
    return jnp.sin(x).sum()

for i in range(1_001):
    res = fn(jnp.zeros(10))  # same input shape on every iteration

    # Drop the Python reference, then explicitly delete every array JAX
    # still reports as live.
    del res
    for live in jax.live_arrays():
        live.delete()

    gc.collect()

    # Report resident memory in MiB every 100 iterations.
    if not i % 100:
        print(f"{i=}: {process.memory_info().rss / 2**20}")
Output
jax version:  0.4.23
i=0: 99.2421875
i=100: 99.27734375
i=200: 99.27734375
i=300: 99.27734375
i=400: 99.27734375
i=500: 99.27734375
i=600: 99.27734375
i=700: 99.28125
i=800: 99.28125
i=900: 99.28125
i=1000: 99.28125

w/out jax.jit() and w/ jax.clear_caches(), and static shapes

Code
import jax
import jax.numpy as jnp
import os
import psutil
import gc

# Handle on the running interpreter, sampled below for resident memory.
process = psutil.Process(pid=os.getpid())

print("jax version: ", jax.__version__)

def fn(x):
    """Return the sum of the element-wise sine of ``x`` (eager, no jit)."""
    return jnp.sin(x).sum()

for i in range(1_001):
    res = fn(jnp.zeros(10))  # same input shape on every iteration

    # Clear JAX's caches while `res` is still referenced (original ordering).
    jax.clear_caches()

    # Drop the Python reference, then explicitly delete every array JAX
    # still reports as live.
    del res
    for live in jax.live_arrays():
        live.delete()

    gc.collect()

    # Report resident memory in MiB every 100 iterations.
    if not i % 100:
        print(f"{i=}: {process.memory_info().rss / 2**20}")
Output
jax version:  0.4.23
i=0: 97.3125
i=100: 107.6484375
i=200: 108.23046875
i=300: 108.65625
i=400: 108.69140625
i=500: 108.80078125
i=600: 109.14453125
i=700: 109.15234375
i=800: 109.15625
i=900: 109.1796875
i=1000: 109.18359375

w/out jax.jit(), w/out jax.clear_caches(), and dynamic shapes

Code
import jax
import jax.numpy as jnp
import os
import psutil
import gc

# Handle on the running interpreter, sampled below for resident memory.
process = psutil.Process(pid=os.getpid())

print("jax version: ", jax.__version__)

def fn(x):
    """Return the sum of the element-wise sine of ``x`` (eager, no jit)."""
    return jnp.sin(x).sum()

for i in range(1_001):
    res = fn(jnp.zeros(i))  # input length equals i, so the shape changes every iteration

    # Drop the Python reference, then explicitly delete every array JAX
    # still reports as live.
    del res
    for live in jax.live_arrays():
        live.delete()

    gc.collect()

    # Report resident memory in MiB every 100 iterations.
    if not i % 100:
        print(f"{i=}: {process.memory_info().rss / 2**20}")
Output
jax version:  0.4.23
i=0: 91.18359375
i=100: 246.37109375
i=200: 390.8203125
i=300: 534.91015625
i=400: 679.5
i=500: 823.40625
i=600: 967.95703125
i=700: 1101.5234375
i=800: 1165.6484375
i=900: 1229.5859375
i=1000: 1291.671875

w/out jax.jit(), w/ jax.clear_caches(), and dynamic shapes

Code
import jax
import jax.numpy as jnp
import os
import psutil
import gc

# Handle on the running interpreter, sampled below for resident memory.
process = psutil.Process(pid=os.getpid())

print("jax version: ", jax.__version__)

def fn(x):
    """Return the sum of the element-wise sine of ``x`` (eager, no jit)."""
    return jnp.sin(x).sum()

for i in range(1_001):
    res = fn(jnp.zeros(i))  # input length equals i, so the shape changes every iteration

    # Clear JAX's caches while `res` is still referenced (original ordering).
    jax.clear_caches()

    # Drop the Python reference, then explicitly delete every array JAX
    # still reports as live.
    del res
    for live in jax.live_arrays():
        live.delete()

    gc.collect()

    # Report resident memory in MiB every 100 iterations.
    if not i % 100:
        print(f"{i=}: {process.memory_info().rss / 2**20}")
Output
jax version:  0.4.23
i=0: 91.578125
i=100: 109.5234375
i=200: 111.6171875
i=300: 103.02734375
i=400: 104.6015625
i=500: 105.97265625
i=600: 107.390625
i=700: 108.90234375
i=800: 110.33984375
i=900: 111.76953125
i=1000: 113.30078125

w/ jax.jit(), w/out jax.clear_caches(), and static shapes

Code
import jax
import jax.numpy as jnp
import os
import psutil
import gc

# Handle on the running interpreter, sampled below for resident memory.
process = psutil.Process(pid=os.getpid())

print("jax version: ", jax.__version__)

@jax.jit
def fn(x):
    """Return the sum of the element-wise sine of ``x``, compiled with ``jax.jit``."""
    return jnp.sin(x).sum()

for i in range(1_001):
    res = fn(jnp.zeros(10))  # same input shape on every iteration

    # Drop the Python reference, then explicitly delete every array JAX
    # still reports as live.
    del res
    for live in jax.live_arrays():
        live.delete()

    gc.collect()

    # Report resident memory in MiB every 100 iterations.
    if not i % 100:
        print(f"{i=}: {process.memory_info().rss / 2**20}")
Output
jax version:  0.4.23
i=0: 97.42578125
i=100: 97.4453125
i=200: 97.44921875
i=300: 97.44921875
i=400: 97.44921875
i=500: 97.453125
i=600: 97.453125
i=700: 97.453125
i=800: 97.453125
i=900: 97.453125
i=1000: 97.453125

w/ jax.jit(), w/ jax.clear_caches(), and static shapes

Code
import jax
import jax.numpy as jnp
import os
import psutil
import gc

# Handle on the running interpreter, sampled below for resident memory.
process = psutil.Process(pid=os.getpid())

print("jax version: ", jax.__version__)

@jax.jit
def fn(x):
    """Return the sum of the element-wise sine of ``x``, compiled with ``jax.jit``."""
    return jnp.sin(x).sum()

for i in range(1_001):
    res = fn(jnp.zeros(10))  # same input shape on every iteration

    # Clear JAX's caches while `res` is still referenced (original ordering).
    jax.clear_caches()

    # Drop the Python reference, then explicitly delete every array JAX
    # still reports as live.
    del res
    for live in jax.live_arrays():
        live.delete()

    gc.collect()

    # Report resident memory in MiB every 100 iterations.
    if not i % 100:
        print(f"{i=}: {process.memory_info().rss / 2**20}")
Output
jax version:  0.4.23
i=0: 97.17578125
i=100: 104.63671875
i=200: 106.3984375
i=300: 106.44140625
i=400: 106.54296875
i=500: 106.5546875
i=600: 106.5625
i=700: 106.58203125
i=800: 106.5859375
i=900: 106.76171875
i=1000: 106.76953125

w/ jax.jit(), w/out jax.clear_caches(), and dynamic shapes

Code
import jax
import jax.numpy as jnp
import os
import psutil
import gc

# Handle on the running interpreter, sampled below for resident memory.
process = psutil.Process(pid=os.getpid())

print("jax version: ", jax.__version__)

@jax.jit
def fn(x):
    """Return the sum of the element-wise sine of ``x``, compiled with ``jax.jit``."""
    return jnp.sin(x).sum()

for i in range(1_001):
    res = fn(jnp.zeros(i))  # input length equals i, so the shape changes every iteration

    # Drop the Python reference, then explicitly delete every array JAX
    # still reports as live.
    del res
    for live in jax.live_arrays():
        live.delete()

    gc.collect()

    # Report resident memory in MiB every 100 iterations.
    if not i % 100:
        print(f"{i=}: {process.memory_info().rss / 2**20}")
Output
jax version:  0.4.23
i=0: 90.64453125
i=100: 200.91796875
i=200: 303.6953125
i=300: 403.40234375
i=400: 505.61328125
i=500: 604.90234375
i=600: 705.3515625
i=700: 804.84765625
i=800: 906.359375
i=900: 1005.77734375
i=1000: 1106.8046875

w/ jax.jit(), w/ jax.clear_caches(), and dynamic shapes

Code
import jax
import jax.numpy as jnp
import os
import psutil
import gc

# Handle on the running interpreter, sampled below for resident memory.
process = psutil.Process(pid=os.getpid())

print("jax version: ", jax.__version__)

@jax.jit
def fn(x):
    """Return the sum of the element-wise sine of ``x``, compiled with ``jax.jit``."""
    return jnp.sin(x).sum()

for i in range(1_001):
    res = fn(jnp.zeros(i))  # input length equals i, so the shape changes every iteration

    # Clear JAX's caches while `res` is still referenced (original ordering).
    jax.clear_caches()

    # Drop the Python reference, then explicitly delete every array JAX
    # still reports as live.
    del res
    for live in jax.live_arrays():
        live.delete()

    gc.collect()

    # Report resident memory in MiB every 100 iterations.
    if not i % 100:
        print(f"{i=}: {process.memory_info().rss / 2**20}")
Output
jax version:  0.4.23
i=0: 90.734375
i=100: 108.7265625
i=200: 111.09375
i=300: 112.77734375
i=400: 114.72265625
i=500: 116.40234375
i=600: 118.37890625
i=700: 119.9921875
i=800: 121.80078125
i=900: 123.40625
i=1000: 125.1953125

This seems to solve the issue.