numpyro: NUTS doesn't converge on a Stan model
@stefanwebb pointed out that Pyro's NUTS on the earnings Latin square model gives extremely different results from Stan. cc @jpchen
To debug this, I have tried running the model with both Pyro's NUTS and Numpyro's NUTS; both return results that are very far off from Stan, with high r_hat values indicating that the sampler hasn't converged. Creating this issue to track progress on investigating this bug / discrepancy.
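(For context, r_hat here is the Gelman-Rubin diagnostic, which compares between-chain and within-chain variance; values well above 1 mean the chains disagree. Below is a minimal sketch of the basic, non-split version for a single scalar parameter, just to illustrate what the diagnostic measures — this is not the implementation used by Pyro or Stan.)

```python
import numpy as np

def gelman_rubin(chains):
    # chains: array of shape (num_chains, num_samples) for one scalar parameter.
    # Basic (non-split) Gelman-Rubin diagnostic; Stan and Pyro use a split variant.
    m, n = chains.shape
    chain_means = chains.mean(axis=1)
    b = n * chain_means.var(ddof=1)        # between-chain variance
    w = chains.var(axis=1, ddof=1).mean()  # within-chain variance
    var_hat = (n - 1) / n * w + b / n      # pooled posterior variance estimate
    return np.sqrt(var_hat / w)

# e.g. two chains exploring different regions give r_hat >> 1:
# gelman_rubin(np.stack([np.random.normal(0, 1, 1000), np.random.normal(5, 1, 1000)]))
```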
Some notes:
- It is possible that my translation of the model to Pyro/Numpyro is wrong.
- I have tried changing the default tensor type to Double in Pyro and changing the scale parameters to `dist.HalfCauchy(1.)` (instead of `dist.Uniform(0, 100)`, which is more faithful to the Stan implementation) to see if that helps convergence. ~~While it does seem to help somewhat, we still get very different results.~~ This seems to help quite a lot on Numpyro (still checking on Pyro).
- Numpyro is much faster than Pyro (I think also faster than Stan), but seems to give incorrect results. This is not surprising, since the underlying issue, either in my code or in the inference algorithm, is likely the same for both implementations.
Pyro code:
```python
import csv
from collections import defaultdict

import torch

import pyro
import pyro.distributions as dist
from pyro.infer.mcmc import NUTS, MCMC

torch.set_default_tensor_type('torch.DoubleTensor')

use_uniform = False


def scale():
    # Prior on scale parameters: Uniform(0, 100) matches the Stan model,
    # HalfCauchy(1.) is the alternative being tested.
    return dist.Uniform(0., 100.) if use_uniform else dist.HalfCauchy(1.)


def model(data):
    eth = data['eth']
    age = data['age']
    x = data['x']
    y = data['y']
    mu_a1 = pyro.sample('mu_a1', dist.Normal(0., 1.))
    mu_a2 = pyro.sample('mu_a2', dist.Normal(0., 1.))
    sigma_a1 = pyro.sample('sigma_a1', scale())
    sigma_a2 = pyro.sample('sigma_a2', scale())
    mu_b1 = pyro.sample('mu_b1', dist.Normal(0., 1.))
    mu_b2 = pyro.sample('mu_b2', dist.Normal(0., 1.))
    sigma_b1 = pyro.sample('sigma_b1', scale())
    sigma_b2 = pyro.sample('sigma_b2', scale())
    mu_c = pyro.sample('mu_c', dist.Normal(0., 1.))
    sigma_c = pyro.sample('sigma_c', scale())
    mu_d = pyro.sample('mu_d', dist.Normal(0., 1.))
    sigma_d = pyro.sample('sigma_d', scale())
    # Group-level effects: 4 ethnicity groups, 3 age groups.
    nage = pyro.plate("n_age", 3, dim=-1)
    neth = pyro.plate("neth", 4, dim=-2)
    with neth:
        a1 = pyro.sample('a1', dist.Normal(10 * mu_a1, sigma_a1))
        a2 = pyro.sample('a2', dist.Normal(mu_a2, sigma_a2))
    with nage:
        b1 = pyro.sample('b1', dist.Normal(10 * mu_b1, sigma_b1))
        b2 = pyro.sample('b2', dist.Normal(0.1 * mu_b2, sigma_b2))
    with neth, nage:
        c = pyro.sample('c', dist.Normal(10 * mu_c, sigma_c))
        d = pyro.sample('d', dist.Normal(0.1 * mu_d, sigma_d))
    y_hat = a1[eth].squeeze(-1) + a2[eth].squeeze(-1) * x + b1[age] + b2[age] * x + c[eth, age] + d[eth, age] * x
    sigma_y = pyro.sample('sigma_y', scale())
    with pyro.plate('N', 1059):
        pyro.sample('obs', dist.Normal(y_hat, sigma_y), obs=y)


# Read the data and convert the 1-based group indices to 0-based.
data = defaultdict(list)
with open('earnings.csv', 'r') as f:
    csv_reader = csv.DictReader(f)
    for row in csv_reader:
        data['x'].append(float(row['x']))
        data['y'].append(float(row['y']))
        data['age'].append(int(row['age']) - 1)
        data['eth'].append(int(row['eth']) - 1)
data['x'] = torch.tensor(data['x'])
data['y'] = torch.tensor(data['y'])
data['age'] = torch.tensor(data['age'], dtype=torch.long)
data['eth'] = torch.tensor(data['eth'], dtype=torch.long)

nuts_kernel = NUTS(model, max_tree_depth=6, jit_compile=True, ignore_jit_warnings=True)
posterior_fully_pooled = MCMC(nuts_kernel,
                              num_samples=500,
                              warmup_steps=500,
                              num_chains=2).run(data)
print(posterior_fully_pooled.marginal(['a1', 'a2', 'b1', 'b2']).diagnostics())
marginals = posterior_fully_pooled.marginal(['a1', 'a2', 'b1', 'b2'])
for k, v in marginals.empirical.items():
    print(k, v.mean)
```
Numpyro code:
```python
import csv
from collections import defaultdict

import jax.numpy as np
from jax.random import PRNGKey

import numpyro.distributions as dist
from numpyro.handlers import sample
from numpyro.hmc_util import initialize_model
from numpyro.mcmc import hmc
from numpyro.util import fori_collect

use_uniform = False


def scale():
    return dist.Uniform(0., 100.) if use_uniform else dist.HalfCauchy(1.)


def model(data):
    eth = data['eth']
    age = data['age']
    x = data['x']
    y = data['y']
    mu_a1 = sample('mu_a1', dist.Normal(0., 1.))
    mu_a2 = sample('mu_a2', dist.Normal(0., 1.))
    sigma_a1 = sample('sigma_a1', scale())
    sigma_a2 = sample('sigma_a2', scale())
    mu_b1 = sample('mu_b1', dist.Normal(0., 1.))
    mu_b2 = sample('mu_b2', dist.Normal(0., 1.))
    sigma_b1 = sample('sigma_b1', scale())
    sigma_b2 = sample('sigma_b2', scale())
    mu_c = sample('mu_c', dist.Normal(0., 1.))
    sigma_c = sample('sigma_c', scale())
    mu_d = sample('mu_d', dist.Normal(0., 1.))
    sigma_d = sample('sigma_d', scale())
    # Group-level effects broadcast to 4 ethnicity groups and 3 age groups.
    a1 = sample('a1', dist.Normal(10 * np.broadcast_to(mu_a1, (4,)), sigma_a1))
    a2 = sample('a2', dist.Normal(np.broadcast_to(mu_a2, (4,)), sigma_a2))
    b1 = sample('b1', dist.Normal(10 * np.broadcast_to(mu_b1, (3,)), sigma_b1))
    b2 = sample('b2', dist.Normal(0.1 * np.broadcast_to(mu_b2, (3,)), sigma_b2))
    c = sample('c', dist.Normal(10 * np.broadcast_to(mu_c, (4, 3)), sigma_c))
    d = sample('d', dist.Normal(0.1 * np.broadcast_to(mu_d, (4, 3)), sigma_d))
    y_hat = a1[eth] + a2[eth] * x + b1[age] + b2[age] * x + c[eth, age] + d[eth, age] * x
    sigma_y = sample('sigma_y', scale())
    sample('obs', dist.Normal(y_hat, sigma_y), obs=y)


# Read the data and convert the 1-based group indices to 0-based.
data = defaultdict(list)
with open('earnings.csv', 'r') as f:
    csv_reader = csv.DictReader(f)
    for row in csv_reader:
        data['x'].append(float(row['x']))
        data['y'].append(float(row['y']))
        data['age'].append(int(row['age']) - 1)
        data['eth'].append(int(row['eth']) - 1)
data['x'] = np.array(data['x'])
data['y'] = np.array(data['y'])
data['age'] = np.array(data['age']).astype(np.int64)
data['eth'] = np.array(data['eth']).astype(np.int64)

init_params, potential_fn, transform_fn = initialize_model(PRNGKey(0), model, data)
init_kernel, sample_kernel = hmc(potential_fn, algo='NUTS')
hmc_state = init_kernel(init_params, 2000)
hmc_states = fori_collect(2000, sample_kernel, hmc_state,
                          transform=lambda hmc_state: transform_fn(hmc_state.z))
print(hmc_states)
```
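To compare against the Stan table below, the posterior means can be read directly off the collected samples. A minimal sketch, under the assumption that `hmc_states` from the snippet above is a dict mapping each site name to an array whose leading dimension indexes samples:

```python
import jax.numpy as np

# Posterior mean of every latent site, assuming a dict of stacked sample arrays.
posterior_means = {name: np.mean(samples, axis=0) for name, samples in hmc_states.items()}
for name, value in sorted(posterior_means.items()):
    print(name, value)
```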
Stan results:
```
mean se_mean sd 2.5% 25% 50% 75% 97.5% n_eff Rhat
a1[1] 5.60 1.57 8.55 -9.18 -1.05 5.24 11.48 22.41 30 1.14
a1[2] 5.44 1.58 8.57 -9.37 -1.17 5.14 11.46 22.07 29 1.14
a1[3] 5.00 1.59 8.60 -9.92 -1.64 4.65 11.04 21.43 29 1.14
a1[4] 5.48 1.57 8.55 -9.31 -1.16 5.14 11.46 22.14 30 1.14
a2[1] 0.07 0.03 0.20 -0.34 -0.06 0.07 0.19 0.42 34 1.06
a2[2] 0.07 0.03 0.20 -0.34 -0.05 0.07 0.19 0.42 33 1.06
a2[3] 0.08 0.03 0.20 -0.33 -0.04 0.08 0.20 0.44 33 1.07
a2[4] 0.07 0.03 0.20 -0.33 -0.05 0.07 0.20 0.43 33 1.06
b1[1] 3.97 1.52 8.68 -14.80 -1.11 3.81 9.56 20.60 33 1.14
b1[2] 2.36 1.52 8.66 -16.69 -2.77 2.25 8.05 18.87 33 1.14
b1[3] 1.78 1.52 8.69 -17.57 -3.48 1.64 7.32 18.27 33 1.13
b2[1] -0.01 0.02 0.16 -0.29 -0.09 -0.02 0.06 0.38 60 1.03
b2[2] 0.02 0.02 0.16 -0.25 -0.06 0.01 0.09 0.40 59 1.03
b2[3] 0.02 0.02 0.16 -0.24 -0.06 0.02 0.10 0.42 59 1.03
c[1,1] -2.43 2.17 7.72 -14.76 -7.90 -3.95 2.13 13.98 13 1.36
c[1,2] -2.36 2.17 7.72 -14.89 -7.85 -3.87 2.21 14.05 13 1.36
c[1,3] -2.46 2.17 7.72 -14.83 -7.98 -4.04 2.13 13.95 13 1.36
c[2,1] -2.37 2.17 7.72 -14.80 -7.87 -3.89 2.23 14.05 13 1.36
c[2,2] -2.51 2.17 7.73 -14.93 -8.04 -4.02 2.11 14.02 13 1.36
c[2,3] -2.38 2.17 7.72 -14.79 -7.89 -3.87 2.24 13.83 13 1.36
c[3,1] -2.41 2.17 7.72 -14.89 -7.89 -3.89 2.23 13.97 13 1.36
c[3,2] -2.44 2.17 7.72 -14.80 -7.94 -3.92 2.19 14.01 13 1.36
c[3,3] -2.43 2.17 7.72 -14.75 -7.92 -3.93 2.19 13.97 13 1.36
c[4,1] -2.44 2.17 7.72 -14.85 -7.94 -3.97 2.17 13.99 13 1.36
c[4,2] -2.38 2.17 7.72 -14.75 -7.86 -3.88 2.19 14.04 13 1.36
c[4,3] -2.42 2.17 7.72 -14.80 -7.91 -3.93 2.19 13.95 13 1.36
d[1,1] -0.02 0.02 0.11 -0.22 -0.10 -0.02 0.06 0.21 29 1.11
d[1,2] -0.02 0.02 0.11 -0.22 -0.10 -0.02 0.06 0.21 29 1.11
d[1,3] -0.02 0.02 0.11 -0.22 -0.10 -0.02 0.06 0.21 29 1.11
d[2,1] -0.02 0.02 0.11 -0.22 -0.10 -0.02 0.06 0.21 29 1.11
d[2,2] -0.02 0.02 0.11 -0.22 -0.10 -0.02 0.06 0.21 29 1.11
d[2,3] -0.02 0.02 0.11 -0.22 -0.10 -0.02 0.06 0.21 29 1.11
d[3,1] -0.02 0.02 0.11 -0.22 -0.10 -0.02 0.06 0.21 29 1.11
d[3,2] -0.02 0.02 0.11 -0.22 -0.10 -0.02 0.06 0.21 29 1.11
d[3,3] -0.02 0.02 0.11 -0.22 -0.10 -0.02 0.06 0.21 29 1.11
d[4,1] -0.02 0.02 0.11 -0.22 -0.10 -0.02 0.06 0.21 29 1.11
d[4,2] -0.02 0.02 0.11 -0.22 -0.10 -0.02 0.06 0.21 29 1.11
d[4,3] -0.02 0.02 0.11 -0.22 -0.10 -0.02 0.06 0.21 29 1.11
mu_a1 0.54 0.16 0.85 -0.94 -0.12 0.50 1.13 2.17 30 1.14
mu_a2 0.07 0.03 0.20 -0.34 -0.05 0.07 0.20 0.43 33 1.06
mu_b1 0.25 0.13 0.85 -1.58 -0.27 0.24 0.79 1.87 42 1.11
mu_b2 0.04 0.09 0.98 -1.90 -0.61 0.03 0.68 2.06 125 1.03
mu_c -0.24 0.22 0.77 -1.48 -0.80 -0.40 0.22 1.40 13 1.36
mu_d -0.17 0.21 1.14 -2.22 -0.98 -0.22 0.63 2.09 29 1.11
sigma_a1 0.96 0.11 1.92 0.02 0.12 0.34 1.01 5.30 319 1.01
sigma_a2 0.01 0.00 0.03 0.00 0.00 0.01 0.02 0.08 328 1.01
sigma_b1 4.07 0.28 6.07 0.14 1.14 2.19 4.43 20.92 487 1.02
sigma_b2 0.12 0.03 0.30 0.00 0.02 0.04 0.09 0.87 138 1.05
sigma_c 0.16 0.01 0.13 0.02 0.06 0.13 0.22 0.48 232 1.02
sigma_d 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.01 179 1.01
sigma_y 0.88 0.00 0.02 0.84 0.86 0.87 0.89 0.91 633 1.01
```
I think that we can close this now?
@neerajprad It turns out that the problem lies in `init_params`. I used the same init params with Stan and get similar results. In Stan, parameters are initialized uniformly in the (-2, 2) interval, while here we used `initial_trace`. The `initial_trace` makes dependent latent variables such as `a1`, `b1` get wild initial values. We might consider supporting the same behaviour as in Stan?

The data and model look correct to me. And Stan did use a Uniform prior for bounded parameters. I can't think of any reason why using Uniform didn't work. Will take a closer look now. 😃
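To make the initialization point above concrete, here is a rough sketch of Stan-style initialization on top of the Numpyro snippet earlier in the issue: draw each init value from Uniform(-2, 2) instead of taking it from the model trace. The helper `uniform_init` is hypothetical (not part of the library) and assumes `init_params` is a flat dict of arrays in unconstrained space:

```python
import jax
import jax.numpy as np

def uniform_init(rng, params, radius=2.):
    # Hypothetical helper: redraw every unconstrained init value from
    # Uniform(-radius, radius), mimicking Stan's default initialization.
    keys = jax.random.split(rng, len(params))
    return {name: jax.random.uniform(key, np.shape(value), minval=-radius, maxval=radius)
            for key, (name, value) in zip(keys, sorted(params.items()))}

# e.g. with the snippet above:
# init_params = uniform_init(PRNGKey(1), init_params)
# hmc_state = init_kernel(init_params, 2000)
```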
I think using `HalfCauchy(1.)` instead of `Uniform(0, 100)` addresses some of these issues. Results on Numpyro using HalfCauchy with 2000 warmup steps and 2000 samples: the parameter values are at least in the vicinity of what we get from Stan.