pyro: MCMC does not work well with multiple chains on CPU due to MemoryError

The following script is copied exactly from the Bayesian regression tutorial, except for the num_samples and num_chains settings. Running it raises a MemoryError. I have tried to debug this by changing various system settings (ulimit, shared memory segment size), but with no luck. I hit the same issue on both systems I have access to, so I suspect it is not specific to my setup. If so, this is a serious problem that may stem from how torch.multiprocessing is used in Pyro.

import numpy as np
import pandas as pd
import torch

import pyro
import pyro.distributions as dist
from pyro.infer.mcmc import MCMC, NUTS

pyro.enable_validation(True)
pyro.set_rng_seed(1)
DATA_URL = "https://d2fefpcigoriu7.cloudfront.net/datasets/rugged_data.csv"
rugged_data = pd.read_csv(DATA_URL, encoding="ISO-8859-1")

# torch.multiprocessing.set_sharing_strategy("file_system")  # tried as a workaround, no effect


def model(is_cont_africa, ruggedness, log_gdp):
    a = pyro.sample("a", dist.Normal(8., 1000.))
    b_a = pyro.sample("bA", dist.Normal(0., 1.))
    b_r = pyro.sample("bR", dist.Normal(0., 1.))
    b_ar = pyro.sample("bAR", dist.Normal(0., 1.))
    sigma = pyro.sample("sigma", dist.Uniform(0., 10.))
    mean = a + b_a * is_cont_africa + b_r * ruggedness + b_ar * is_cont_africa * ruggedness
    with pyro.iarange("data", len(ruggedness)):
        pyro.sample("obs", dist.Normal(mean, sigma), obs=log_gdp)

        
df = rugged_data[["cont_africa", "rugged", "rgdppc_2000"]]
df = df[np.isfinite(df.rgdppc_2000)]
df["rgdppc_2000"] = np.log(df["rgdppc_2000"])
train = torch.tensor(df.values, dtype=torch.float)
is_cont_africa, ruggedness, log_gdp = train[:, 0], train[:, 1], train[:, 2]

nuts_kernel = NUTS(model, adapt_step_size=True)
hmc_posterior = MCMC(nuts_kernel, num_samples=4000, warmup_steps=1000, num_chains=4).run(is_cont_africa, ruggedness, log_gdp)

Here is its backtrace:


Most upvoted comments

@fehiepsi - This seems to work fine now. Could you check this again on your system and close this issue if it is resolved?

@neerajprad, for the above example (with jit_compile=True, num_samples=500 to avoid the MemoryError, and warmup_steps=20 to reduce its effect on timing; a sketch of these settings follows the list):

  • num_chains=2: 2s vs 5s (averaged over all chains: 180 it/s vs 90 it/s)
  • num_chains=4: 3s vs 10s (160 it/s vs 50 it/s)
  • num_chains=6: 5s vs 15s (100 it/s vs 33 it/s)
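
For reference, here is a minimal sketch of the setup used for these timings, assuming the same model and data as the script above; the exact kernel arguments are my reconstruction, not copied verbatim from the comment:

# Sketch (assumed settings): enable JIT compilation on the NUTS kernel and
# shrink the run so that the MemoryError is not triggered.
nuts_kernel = NUTS(model, adapt_step_size=True, jit_compile=True)
posterior = MCMC(nuts_kernel, num_samples=500, warmup_steps=20,
                 num_chains=2).run(is_cont_africa, ruggedness, log_gdp)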

I think we are hitting some system-level limit on shared memory. If you change _ParallelSampler to yield None, args[1].. instead of the trace, you will find that no error is thrown. I think what is happening is that even though we read from the queue, the objects remain in shared memory, and after a certain point we hit some limit that causes the OS to kill the process.
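
To make the mechanism concrete, here is a minimal standalone sketch (not Pyro's actual _ParallelSampler) illustrating that tensors received from a torch.multiprocessing queue are still backed by shared-memory storage, so each queued object keeps a segment or file descriptor alive:

import torch
import torch.multiprocessing as mp


def producer(queue, n):
    # Every tensor put on the queue is moved into shared memory before
    # being handed to the parent process.
    for _ in range(n):
        queue.put(torch.randn(1000))
    queue.put(None)


if __name__ == "__main__":
    ctx = mp.get_context("spawn")
    queue = ctx.Queue()
    worker = ctx.Process(target=producer, args=(queue, 5))
    worker.start()
    while True:
        t = queue.get()
        if t is None:
            break
        # The received tensor shares storage with the producer's copy;
        # while it (or the queue's internal buffers) is alive, the
        # shared-memory segment / file descriptor stays open.
        print(t.is_shared())  # prints True
    worker.join()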

One option would be to find out which limit we are hitting and increase it in sysctl.conf, but I don't think we should do that. I think this issue will be resolved once #1725 is resolved, since we won't keep the objects in shared memory but will reduce them to some other representation instead. This will ensure that our shared-memory consumption stays fixed for the entire sampling run.
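
A rough illustration of that idea (my own sketch, not the actual change proposed in #1725): the worker would reduce each trace to plain host-memory values before putting anything on the queue, so the shared-memory footprint no longer grows with the number of samples:

def trace_to_plain_values(trace):
    # Hypothetical helper: keep only the sampled values from a Pyro trace
    # and copy them out of (potentially shared) tensor storage, so nothing
    # sent back to the parent process pins a shared-memory segment.
    return {name: node["value"].detach().clone().numpy()
            for name, node in trace.nodes.items()
            if node["type"] == "sample"}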