numpyro: NUTS doesn't converge on a Stan model

@stefanwebb pointed out that Pyro’s NUTS on the earnings Latin square model gives extremely different results from Stan. cc @jpchen

To debug this, I have tried running the model on both Pyro’s NUTS and Numpyro’s NUTS; both return results that are very far off from Stan, with high r_hat values indicating that the sampler hasn’t converged. Creating this issue to track progress on investigating this bug / discrepancy.
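For context, here is a minimal sketch of the split R-hat diagnostic being referred to (not from the issue; plain NumPy, included just to make "high r_hat" concrete: values well above 1.0 mean the chains disagree):

import numpy as onp  # plain NumPy, kept distinct from jax.numpy

def split_rhat(chains):
    # chains: array of shape (num_chains, num_samples) for a single scalar site
    half = chains.shape[1] // 2
    # split every chain in half so within-chain drift also inflates the statistic
    splits = onp.concatenate([chains[:, :half], chains[:, half:2 * half]], axis=0)
    n = splits.shape[1]
    W = splits.var(axis=1, ddof=1).mean()    # mean within-chain variance
    B = n * splits.mean(axis=1).var(ddof=1)  # between-chain variance
    var_hat = (n - 1) / n * W + B / n        # pooled posterior variance estimate
    return float(onp.sqrt(var_hat / W))      # ~1.0 once the chains have mixed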

Some notes:

  • It is possible that my translation of the model to Pyro/Numpyro is wrong.
  • I have tried changing the default tensor type to Double in Pyro, and changing the scale parameters to dist.HalfCauchy(1.) (instead of dist.Uniform(0, 100), which is more faithful to the Stan implementation) to see if that helps convergence. ~~While it does seem to help somewhat, we still get very different results.~~ This seems to help quite a lot on Numpyro (still checking on Pyro).
  • Numpyro is much faster than Pyro (I think also faster than Stan), but seems to give incorrect results. Not surprising, since the underlying issue, either in my code or in the inference algorithm, is likely the same for both implementations.

Pyro code:

import csv
from collections import defaultdict

import torch

import pyro
import pyro.distributions as dist
from pyro.infer.mcmc import NUTS, MCMC


torch.set_default_tensor_type('torch.DoubleTensor')
use_uniform = False


def scale():
    return dist.Uniform(0., 100.) if use_uniform else dist.HalfCauchy(1.)


def model(data):
    eth = data['eth']
    age = data['age']
    x = data['x']
    y = data['y']
    mu_a1 = pyro.sample('mu_a1', dist.Normal(0., 1.))
    mu_a2 = pyro.sample('mu_a2', dist.Normal(0., 1.))
    sigma_a1 = pyro.sample('sigma_a1', scale())
    sigma_a2 = pyro.sample('sigma_a2', scale())

    mu_b1 = pyro.sample('mu_b1', dist.Normal(0., 1.))
    mu_b2 = pyro.sample('mu_b2', dist.Normal(0., 1.))
    sigma_b1 = pyro.sample('sigma_b1', scale())
    sigma_b2 = pyro.sample('sigma_b2', scale())

    mu_c = pyro.sample('mu_c', dist.Normal(0., 1.))
    sigma_c = pyro.sample('sigma_c', scale())
    mu_d = pyro.sample('mu_d', dist.Normal(0., 1.))
    sigma_d = pyro.sample('sigma_d', scale())

    # group-level plates: 3 age categories and 4 ethnicity categories, kept on
    # separate dims so their samples broadcast against each other
    nage = pyro.plate("n_age", 3, dim=-1)
    neth = pyro.plate("neth", 4, dim=-2)

    with neth:
        a1 = pyro.sample('a1', dist.Normal(10 * mu_a1, sigma_a1))
        a2 = pyro.sample('a2', dist.Normal(mu_a2, sigma_a2))

    with nage:
        b1 = pyro.sample('b1', dist.Normal(10 * mu_b1, sigma_b1))
        b2 = pyro.sample('b2', dist.Normal(0.1 * mu_b2, sigma_b2))

    with neth, nage:
        c = pyro.sample('c', dist.Normal(10 * mu_c, sigma_c))
        d = pyro.sample('d', dist.Normal(0.1 * mu_d, sigma_d))

    y_hat = (a1[eth].squeeze(-1) + a2[eth].squeeze(-1) * x
             + b1[age] + b2[age] * x
             + c[eth, age] + d[eth, age] * x)
    sigma_y = pyro.sample('sigma_y', scale())
    with pyro.plate('N', 1059):
        pyro.sample('obs', dist.Normal(y_hat, sigma_y), obs=y)


data = defaultdict(list)
with open('earnings.csv', 'r') as f:
    csv_reader = csv.DictReader(f)
    for row in csv_reader:
        data['x'].append(float(row['x']))
        data['y'].append(float(row['y']))
        data['age'].append(int(row['age']) - 1)  # shift 1-based Stan indices to 0-based
        data['eth'].append(int(row['eth']) - 1)


data['x'] = torch.tensor(data['x'])
data['y'] = torch.tensor(data['y'])
data['age'] = torch.tensor(data['age'], dtype=torch.long)
data['eth'] = torch.tensor(data['eth'], dtype=torch.long)


nuts_kernel = NUTS(model, max_tree_depth=6, jit_compile=True, ignore_jit_warnings=True)
posterior_fully_pooled = MCMC(nuts_kernel,
                              num_samples=500,
                              warmup_steps=500,
                              num_chains=2).run(data)

marginals = posterior_fully_pooled.marginal(['a1', 'a2', 'b1', 'b2'])
print(marginals.diagnostics())
for k, v in marginals.empirical.items():
    print(k, v.mean)

Numpyro code:

import csv
from collections import defaultdict

from jax.random import PRNGKey

import numpyro.distributions as dist
from numpyro.handlers import sample
import jax.numpy as np

from numpyro.hmc_util import initialize_model
from numpyro.mcmc import hmc
from numpyro.util import fori_collect

use_uniform = False


def scale():
    return dist.Uniform(0., 100.) if use_uniform else dist.HalfCauchy(1.)


def model(data):
    eth = data['eth']
    age = data['age']
    x = data['x']
    y = data['y']
    mu_a1 = sample('mu_a1', dist.Normal(0., 1.))
    mu_a2 = sample('mu_a2', dist.Normal(0., 1.))
    sigma_a1 = sample('sigma_a1', scale())
    sigma_a2 = sample('sigma_a2', scale())

    mu_b1 = sample('mu_b1', dist.Normal(0., 1.))
    mu_b2 = sample('mu_b2', dist.Normal(0., 1.))
    sigma_b1 = sample('sigma_b1', scale())
    sigma_b2 = sample('sigma_b2', scale())

    mu_c = sample('mu_c', dist.Normal(0., 1.))
    sigma_c = sample('sigma_c', scale())
    mu_d = sample('mu_d', dist.Normal(0., 1.))
    sigma_d = sample('sigma_d', scale())

    # broadcast the scalar group means to the group shapes by hand
    # (4 ethnicity categories, 3 age categories) instead of using plates
    a1 = sample('a1', dist.Normal(10 * np.broadcast_to(mu_a1, (4,)), sigma_a1))
    a2 = sample('a2', dist.Normal(np.broadcast_to(mu_a2, (4,)), sigma_a2))

    b1 = sample('b1', dist.Normal(10 * np.broadcast_to(mu_b1, (3,)), sigma_b1))
    b2 = sample('b2', dist.Normal(0.1 * np.broadcast_to(mu_b2, (3,)), sigma_b2))

    c = sample('c', dist.Normal(10 * np.broadcast_to(mu_c, (4, 3)), sigma_c))
    d = sample('d', dist.Normal(0.1 * np.broadcast_to(mu_d, (4, 3)), sigma_d))

    y_hat = a1[eth] + a2[eth] * x + b1[age] + b2[age] * x + c[eth, age] + d[eth, age] * x
    sigma_y = sample('sigma_y', scale())
    sample('obs', dist.Normal(y_hat, sigma_y), obs=y)


data = defaultdict(list)
with open('earnings.csv', 'r') as f:
    csv_reader = csv.DictReader(f)
    for row in csv_reader:
        data['x'].append(float(row['x']))
        data['y'].append(float(row['y']))
        data['age'].append(int(row['age']) - 1)  # shift 1-based Stan indices to 0-based
        data['eth'].append(int(row['eth']) - 1)


data['x'] = np.array(data['x'])
data['y'] = np.array(data['y'])
data['age'] = np.array(data['age']).astype(np.int64)
data['eth'] = np.array(data['eth']).astype(np.int64)

init_params, potential_fn, transform_fn = initialize_model(PRNGKey(0), model, data)
init_kernel, sample_kernel = hmc(potential_fn, algo='NUTS')
hmc_state = init_kernel(init_params, 2000)
hmc_states = fori_collect(2000, sample_kernel, hmc_state,
                          transform=lambda hmc_state: transform_fn(hmc_state.z))
print(hmc_states)
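(Aside, not in the original script: assuming the collected hmc_states behaves as a dict mapping each site name to an array of samples stacked along the leading axis, the posterior means compared against Stan below can be read off with:)

for name, samples in hmc_states.items():
    print(name, np.mean(samples, axis=0))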

Stan results:

               mean se_mean    sd    2.5%     25%     50%     75%   97.5% n_eff Rhat
a1[1]          5.60    1.57  8.55   -9.18   -1.05    5.24   11.48   22.41    30 1.14
a1[2]          5.44    1.58  8.57   -9.37   -1.17    5.14   11.46   22.07    29 1.14
a1[3]          5.00    1.59  8.60   -9.92   -1.64    4.65   11.04   21.43    29 1.14
a1[4]          5.48    1.57  8.55   -9.31   -1.16    5.14   11.46   22.14    30 1.14
a2[1]          0.07    0.03  0.20   -0.34   -0.06    0.07    0.19    0.42    34 1.06
a2[2]          0.07    0.03  0.20   -0.34   -0.05    0.07    0.19    0.42    33 1.06
a2[3]          0.08    0.03  0.20   -0.33   -0.04    0.08    0.20    0.44    33 1.07
a2[4]          0.07    0.03  0.20   -0.33   -0.05    0.07    0.20    0.43    33 1.06
b1[1]          3.97    1.52  8.68  -14.80   -1.11    3.81    9.56   20.60    33 1.14
b1[2]          2.36    1.52  8.66  -16.69   -2.77    2.25    8.05   18.87    33 1.14
b1[3]          1.78    1.52  8.69  -17.57   -3.48    1.64    7.32   18.27    33 1.13
b2[1]         -0.01    0.02  0.16   -0.29   -0.09   -0.02    0.06    0.38    60 1.03
b2[2]          0.02    0.02  0.16   -0.25   -0.06    0.01    0.09    0.40    59 1.03
b2[3]          0.02    0.02  0.16   -0.24   -0.06    0.02    0.10    0.42    59 1.03
c[1,1]        -2.43    2.17  7.72  -14.76   -7.90   -3.95    2.13   13.98    13 1.36
c[1,2]        -2.36    2.17  7.72  -14.89   -7.85   -3.87    2.21   14.05    13 1.36
c[1,3]        -2.46    2.17  7.72  -14.83   -7.98   -4.04    2.13   13.95    13 1.36
c[2,1]        -2.37    2.17  7.72  -14.80   -7.87   -3.89    2.23   14.05    13 1.36
c[2,2]        -2.51    2.17  7.73  -14.93   -8.04   -4.02    2.11   14.02    13 1.36
c[2,3]        -2.38    2.17  7.72  -14.79   -7.89   -3.87    2.24   13.83    13 1.36
c[3,1]        -2.41    2.17  7.72  -14.89   -7.89   -3.89    2.23   13.97    13 1.36
c[3,2]        -2.44    2.17  7.72  -14.80   -7.94   -3.92    2.19   14.01    13 1.36
c[3,3]        -2.43    2.17  7.72  -14.75   -7.92   -3.93    2.19   13.97    13 1.36
c[4,1]        -2.44    2.17  7.72  -14.85   -7.94   -3.97    2.17   13.99    13 1.36
c[4,2]        -2.38    2.17  7.72  -14.75   -7.86   -3.88    2.19   14.04    13 1.36
c[4,3]        -2.42    2.17  7.72  -14.80   -7.91   -3.93    2.19   13.95    13 1.36
d[1,1]        -0.02    0.02  0.11   -0.22   -0.10   -0.02    0.06    0.21    29 1.11
d[1,2]        -0.02    0.02  0.11   -0.22   -0.10   -0.02    0.06    0.21    29 1.11
d[1,3]        -0.02    0.02  0.11   -0.22   -0.10   -0.02    0.06    0.21    29 1.11
d[2,1]        -0.02    0.02  0.11   -0.22   -0.10   -0.02    0.06    0.21    29 1.11
d[2,2]        -0.02    0.02  0.11   -0.22   -0.10   -0.02    0.06    0.21    29 1.11
d[2,3]        -0.02    0.02  0.11   -0.22   -0.10   -0.02    0.06    0.21    29 1.11
d[3,1]        -0.02    0.02  0.11   -0.22   -0.10   -0.02    0.06    0.21    29 1.11
d[3,2]        -0.02    0.02  0.11   -0.22   -0.10   -0.02    0.06    0.21    29 1.11
d[3,3]        -0.02    0.02  0.11   -0.22   -0.10   -0.02    0.06    0.21    29 1.11
d[4,1]        -0.02    0.02  0.11   -0.22   -0.10   -0.02    0.06    0.21    29 1.11
d[4,2]        -0.02    0.02  0.11   -0.22   -0.10   -0.02    0.06    0.21    29 1.11
d[4,3]        -0.02    0.02  0.11   -0.22   -0.10   -0.02    0.06    0.21    29 1.11
mu_a1          0.54    0.16  0.85   -0.94   -0.12    0.50    1.13    2.17    30 1.14
mu_a2          0.07    0.03  0.20   -0.34   -0.05    0.07    0.20    0.43    33 1.06
mu_b1          0.25    0.13  0.85   -1.58   -0.27    0.24    0.79    1.87    42 1.11
mu_b2          0.04    0.09  0.98   -1.90   -0.61    0.03    0.68    2.06   125 1.03
mu_c          -0.24    0.22  0.77   -1.48   -0.80   -0.40    0.22    1.40    13 1.36
mu_d          -0.17    0.21  1.14   -2.22   -0.98   -0.22    0.63    2.09    29 1.11
sigma_a1       0.96    0.11  1.92    0.02    0.12    0.34    1.01    5.30   319 1.01
sigma_a2       0.01    0.00  0.03    0.00    0.00    0.01    0.02    0.08   328 1.01
sigma_b1       4.07    0.28  6.07    0.14    1.14    2.19    4.43   20.92   487 1.02
sigma_b2       0.12    0.03  0.30    0.00    0.02    0.04    0.09    0.87   138 1.05
sigma_c        0.16    0.01  0.13    0.02    0.06    0.13    0.22    0.48   232 1.02
sigma_d        0.00    0.00  0.00    0.00    0.00    0.00    0.00    0.01   179 1.01
sigma_y        0.88    0.00  0.02    0.84    0.86    0.87    0.89    0.91   633 1.01

earnings.tar.gz

About this issue

  • State: closed
  • Created 5 years ago
  • Reactions: 1
  • Comments: 15 (15 by maintainers)

Most upvoted comments

I think that we can close this now?

@neerajprad It turns out that the problem lies in init_params. I used the same init params as Stan and got similar results. In Stan, params are initialized uniformly in the (-2, 2) interval (on the unconstrained scale), while here we used the initial_trace. The initial_trace gives dependent latent variables such as a1 and b1 wildly off initial values. We might consider supporting the same behaviour as Stan?
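For illustration, a minimal sketch (not from the thread) of what Stan-style initialization could look like with the Numpyro script above, reusing its model, data, initialize_model and hmc, and assuming, as in this early API, that init_params holds the unconstrained values consumed by potential_fn:

from jax import random

init_params, potential_fn, transform_fn = initialize_model(PRNGKey(0), model, data)

# replace each model-drawn initial value with a Uniform(-2, 2) draw of the same
# shape, mimicking Stan's default initialization on the unconstrained scale
rng = PRNGKey(1)
stan_style_init = {}
for name, value in init_params.items():
    rng, rng_site = random.split(rng)
    stan_style_init[name] = random.uniform(rng_site, np.shape(value), minval=-2., maxval=2.)

init_kernel, sample_kernel = hmc(potential_fn, algo='NUTS')
hmc_state = init_kernel(stan_style_init, 2000)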

The data and model look correct to me. And Stan did use a Uniform prior for the bounded parameters. I can’t think of any reason why using Uniform didn’t work. Will take a closer look now. 😃

I think using HalfCauchy(1.) instead of Uniform(0, 100) addresses some of these issues. Below are the results on Numpyro using HalfCauchy with 2000 warmup steps and 2000 samples. The parameter values are at least in the vicinity of what we get from Stan.

a1
[3.0880034 3.0277355 2.8439357 3.0560315]

a2
[0.06595199 0.0670628  0.07130343 0.06875928]

b1
[2.2987626 1.3927749 1.1088973]

b2
[-0.03455285 -0.01366357 -0.01128363]

c
[[1.038227   1.120141   1.0181044 ]
 [1.1059053  0.97992927 1.0885361 ]
 [1.069844   1.0359606  1.0458062 ]
 [1.0400462  1.1041307  1.0638354 ]]

d
[[0.01215129 0.0131588  0.01172877]
 [0.01290246 0.01140532 0.01308942]
 [0.0125756  0.01269642 0.01297002]
 [0.01225842 0.01312897 0.01268698]]

mu_a1
0.29997757

mu_a2
0.068249255

mu_b1
0.15676731

mu_b2
-0.1279364

mu_c
0.1059214

mu_d
0.12550469

sigma_a1
0.38256466

sigma_a2
0.008176872

sigma_b1
1.0645754

sigma_b2
0.057657257

sigma_c
0.15156718

sigma_d
0.0021564718

sigma_y
0.8757536