numpyro: NUTS sometimes does not converge on a regression model
I observed that NUTS sometimes does not converge on this model, depending on the RNG seed. For instance, if I change the seed to 7227991, it converges. Overall it fails to converge about 2-3 times out of 10. I am guessing that it sometimes does not choose a good initial point for `sigma`/`w`, and hence the chain diverges?
The model seems to work fine with Stan, though, i.e. it converges and gives correct results.
Is this expected?
```python
import jax.random
import numpy as np
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

numpyro.set_platform("cpu")
numpyro.set_host_device_count(4)  # one CPU device per chain

data = dict()
data['x'] = np.array([0.0970,2.1020,0.5840,1.0394,3.4375,1.3102,1.2863,4.5382,1.2539,2.9319,4.7777,4.5937,4.0403,0.7749,1.8342,1.5008,3.9557,1.6095,0.5602,4.7997,1.6436,4.5236,1.9404,1.1508,3.1447,3.3551,0.2669,3.2483,2.5293,0.1596,2.8992,2.5016,4.7589,2.5648,1.9871,0.3201,4.4863,0.5675,0.8897,0.5689,0.5621,0.5701,0.7998,0.3166,4.8490,2.7849,2.6858,4.2393,1.4939,1.7230,0.4202,0.3591,2.5912,3.9753,3.5417,0.3198,2.7192,1.9892,1.9525,4.7293,1.7719,3.7889,2.3613,4.6994,2.0155,3.8840,0.8289,2.1164,1.3342,1.4850,0.5512,0.3919,2.4001,3.1949,4.9759,1.4158,0.6088,2.1599,0.3643,0.6800,2.4031,1.7706,3.7669,4.1080,4.9707,1.9139,4.7615,3.3075,3.6872,4.0946,0.3447,2.7292,0.2768,3.7641,3.3987,0.1833,4.9353,1.9730,2.2034,2.8598])
data['y'] = np.array([4.9404,45.0399,14.6792,23.7890,71.7504,29.2036,28.7251,93.7642,28.0774,61.6389,98.5535,94.8738,83.8054,18.4976,39.6836,33.0162,82.1149,35.1902,14.2033,98.9948,35.8724,93.4728,41.8081,26.0160,65.8950,70.1028,8.3385,67.9659,53.5864,6.1922,60.9839,53.0324,98.1783,54.2960,42.7421,9.4019,92.7252,14.3490,20.7931,14.3781,14.2424,14.4019,18.9951,9.3322,99.9797,58.6983,56.7166,87.7854,32.8781,37.4594,11.4034,10.1829,54.8241,82.5061,73.8340,9.3964,57.3845,42.7835,42.0501,97.5864,38.4374,78.7779,50.2264,96.9873,43.3109,80.6796,19.5773,45.3275,29.6839,32.6994,14.0248,10.8371,51.0026,66.8986,102.5188,31.3158,15.1756,46.1974,10.2855,16.6008,51.0629,38.4128,78.3385,85.1597,102.4138,41.2770,98.2303,69.1497,76.7434,84.8919,9.8933,57.5845,8.5367,78.2816,70.9734,6.6665,101.7066,42.4600,47.0682,60.1959])

def model():
    # Priors on slope, intercept and noise scale
    w = numpyro.sample("w", dist.Exponential(87.2889404296875))
    b = numpyro.sample("b", dist.Normal(1, 1))
    sigma = numpyro.sample("sigma", dist.Chi2(31.628190994262695))
    # Linear-regression likelihood over all observations
    with numpyro.plate("size", np.size(data['y'])):
        numpyro.sample("obs55", dist.Normal(w * data['x'] + b, sigma), obs=data['y'])

mcmc = MCMC(NUTS(model), num_samples=1000, num_warmup=1000, num_chains=4)
mcmc.run(jax.random.PRNGKey(323728029))
mcmc.print_summary()
```
Output:
```
           mean       std    median      5.0%     95.0%     n_eff     r_hat
      b    2.75      0.89      3.00      0.98      4.05   1438.28      1.01
  sigma   39.80     23.18     51.43      0.00     56.96      2.03      8.17
      w    5.01      8.66      0.01      0.00     20.00      2.00    890.88
```
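The `r_hat` of 890.88 for `w` above says the four chains disagree badly about `w`. A minimal NumPy sketch of the classic (non-split) Gelman-Rubin statistic, run on simulated draws rather than the actual chains, shows both sides of this: three chains near `0` plus one near `20` inflate R-hat enormously, while four chains stuck in the same place give R-hat close to 1 whether or not that place is right. The helper `gelman_rubin` is hypothetical and written for illustration; NumPyro's summary reports a split variant, so its exact values differ.

```python
import numpy as np

def gelman_rubin(chains):
    """Classic (non-split) potential scale reduction factor.

    chains: array of shape (num_chains, num_draws) for one scalar parameter.
    """
    chains = np.asarray(chains, dtype=float)
    m, n = chains.shape
    # W: average within-chain variance
    W = chains.var(axis=1, ddof=1).mean()
    # B: between-chain variance of the chain means, scaled by n
    B = n * chains.mean(axis=1).var(ddof=1)
    var_hat = (n - 1) / n * W + B / n
    return np.sqrt(var_hat / W)

rng = np.random.default_rng(0)
n = 1000
# Simulated draws of w: three chains stuck near 0, one chain near 20
stuck = rng.exponential(scale=0.01, size=(3, n))
escaped = 20.0 + 0.01 * rng.standard_normal((1, n))
mixed = np.concatenate([stuck, escaped], axis=0)
# Simulated draws of w: all four chains stuck in the same place near 0
all_stuck = rng.exponential(scale=0.01, size=(4, n))

print(gelman_rubin(mixed))      # far above 1: chains disagree
print(gelman_rubin(all_stuck))  # close to 1: chains agree, right or wrong
```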
Environment:
- Python 3.7
- NumPyro 0.9.0
- Ubuntu 18.04
Stan output:
```
           Mean     MCSE   StdDev        5%       50%      95%    N_Eff  N_Eff/s    R_hat
lp__       -424  2.8e-02      1.2  -4.3e+02  -4.2e+02     -422  2.0e+03  1.4e+04  1.0e+00
w         0.012  2.1e-04    0.013   6.2e-04   8.4e-03    0.038  3.9e+03  2.7e+04  1.0e+00
b           2.7  2.0e-02      1.0   9.6e-01   2.6e+00      4.3  2.5e+03  1.7e+04  1.0e+00
sigma        53  6.8e-02      3.4   4.8e+01   5.3e+01       59  2.6e+03  1.8e+04  1.0e+00
```
Please let me know if these reports are useful. I am testing NumPyro against various models to check for bugs. Since I am new to NumPyro, some of these observations may already be known to the developers (i.e., not bugs).
About this issue
- State: closed
- Created 2 years ago
- Comments: 20 (10 by maintainers)
Glad that you figured it out! I guess we can close this issue? Just a comment: it is better to reparameterize your model.
I don’t know. I guess so. Stan might generate a collection of init values and choose the best one to start with.
I’m not sure what a reasonable initialization strategy would be. In the above model the chains are not mixing, and to me that could be a good signal to explore more rather than something to ignore (ignoring it might omit areas that have small probability). Typically, I rarely use strong priors, and I try to scale parameters to unit scale, e.g.
`w = numpyro.sample("w_base", dist.Exponential(1)) / 87.2889404296875`
I guess it is related to please post it here.
The problem is with the specification of your model rather than with NumPyro. Eyeballing the data, it looks like there is very strong signal and very little noise.
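To make "very strong signal, very little noise" concrete, here is a minimal NumPy sketch that fits an ordinary least-squares line with `np.polyfit` to the first eight `(x, y)` pairs from the data in the issue. Using only eight points is a choice made to keep the example short.

```python
import numpy as np

# First eight (x, y) pairs from the data in the issue
x = np.array([0.0970, 2.1020, 0.5840, 1.0394, 3.4375, 1.3102, 1.2863, 4.5382])
y = np.array([4.9404, 45.0399, 14.6792, 23.7890, 71.7504, 29.2036, 28.7251, 93.7642])

# Ordinary least squares for y = w * x + b; polyfit returns [slope, intercept]
w_hat, b_hat = np.polyfit(x, y, deg=1)
residuals = y - (w_hat * x + b_hat)

print(f"w ~ {w_hat:.3f}, b ~ {b_hat:.3f}, residual std ~ {residuals.std():.4f}")
```

The slope comes out very close to `20` and the intercept very close to `3`, with a residual standard deviation well below `0.01`.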
I estimate the true values of `w` and `b` that you used to generate the data were `20` and `3` respectively.

Your prior on `w`, however (`dist.Exponential(87.2889404296875)`), is extremely concentrated near `0`. This is like telling the model it is overwhelmingly likely that the true value of `w` is in the range `0` to `0.1`, several orders of magnitude smaller than the true value of `20`.

Running the model with different seeds leads to R-hat statistics close to `1`, but it doesn't seem to me that the model is doing the right thing. In fact, your R-hat is `1` because all chains are stuck at `0` (where all the prior probability is) and can't make the jump to `20` (where the likelihood is large).

Judging from the Stan output, the same is happening there: all chains get stuck in the same place rather than converging to the right answer. The NumPyro summary you posted shows an average `w` of `5` because one chain makes it to `20` while the other three get stuck at `0`.

You can get a lot more help with modelling questions on the Pyro forum.
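The point about the prior on `w` can be checked in closed form. For an `Exponential(rate)` prior the mean is `1 / rate` and the tail probability is `P(w > t) = exp(-rate * t)`. The short sketch below evaluates both with Python's `math` module; the variable names are illustrative.

```python
import math

RATE = 87.2889404296875  # rate of the Exponential prior on w in the issue's model

prior_mean = 1.0 / RATE              # mean of Exponential(rate)
p_above_0_1 = math.exp(-RATE * 0.1)  # P(w > 0.1) = exp(-rate * 0.1)
log_p_above_20 = -RATE * 20.0        # log P(w > 20); exp() of this underflows to 0.0

print(f"prior mean of w: {prior_mean:.4f}")
print(f"P(w > 0.1) = {p_above_0_1:.2e}")
print(f"log P(w > 20) = {log_p_above_20:.1f}")
```

The prior mean of `w` is about `0.0115`, `P(w > 0.1)` is about `1.6e-4`, and `P(w > 20)` is about `e^-1746`, far below the smallest positive float64.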