I've observed that NUTS has trouble converging when I use a LogNormal distribution with a large sigma. I tried different initializations, but that doesn't seem to change the results. If I change mu to -2, it converges to the correct result. Any ideas what is wrong here?
I get correct results with Stan and NumPyro using the same model.
import pymc3 as pm
import theano
import arviz as az
import numpy as np
data = dict()
data['x'] = theano.shared(np.array([
    [0.3127, 0.7185, 0.9494, 0.5597, 0.7358, 0.1210, 0.6191, 0.5667, 0.4503, 0.5986],
    [0.0269, 0.7518, 0.9182, 0.6499, 0.4869, 0.3073, 0.8153, 0.6999, 0.0878, 0.1670],
    [0.7447, 0.5607, 0.0876, 0.7848, 0.5527, 0.6284, 0.2125, 0.6352, 0.8686, 0.7983],
    [0.1576, 0.0972, 0.6508, 0.7716, 0.6274, 0.1363, 0.7613, 0.2494, 0.3428, 0.9261],
    [0.7717, 0.4577, 0.0172, 0.8652, 0.8721, 0.1606, 0.9938, 0.0563, 0.5016, 0.5695],
    [0.1718, 0.2058, 0.9203, 0.6626, 0.5851, 0.1158, 0.4355, 0.0705, 0.7436, 0.2993],
    [0.3664, 0.5945, 0.0679, 0.5061, 0.9727, 0.2606, 0.8866, 0.7922, 0.8299, 0.4857],
    [0.1521, 0.6116, 0.8895, 0.8145, 0.8965, 0.3734, 0.5860, 0.2676, 0.9310, 0.6909],
    [0.6845, 0.6834, 0.3010, 0.4828, 0.8348, 0.3919, 0.4748, 0.6864, 0.1595, 0.2750],
    [0.1263, 0.8909, 0.6362, 0.1842, 0.4666, 0.9637, 0.3681, 0.4188, 0.3020, 0.7719],
]).astype("float64"))
data['y'] = theano.shared(np.array([14.3773, 22.0280, 21.5997, 24.4789, 25.8981, 15.0056, 20.0997, 17.3859, 20.7086, 22.1846]).astype("float64"))
with pm.Model() as model:
    w_param_1 = pm.Lognormal('w_param_1', mu=-22.047279357910156, sigma=25.498497009277344)
    w_param_0 = pm.InverseGamma('w_param_0', alpha=17.76064682006836, beta=32.966636657714844)
    w = pm.Normal('w', w_param_0, w_param_1, shape=(10,))
    b = pm.Normal('b', 1.0, 10.0)
    obs1 = pm.Normal('obs1', theano.tensor.dot(data['x'], w) + b, 1.0, observed=data['y'])

with model:
    step = pm.step_methods.hmc.nuts.NUTS(max_treedepth=10, target_accept=0.8, step_scale=0.25)
    samples = pm.sample(draws=1000, chains=4, tune=1000, step=step, init="auto", jitter_max_retries=10)

print(az.summary(samples))
Output:
NUTS: [b, w, w_param_0, w_param_1]
Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 94 seconds.
The acceptance probability does not match the target. It is 0.9492248356913224, but should be close to 0.8. Try to increase the number of tuning steps.
There were 559 divergences after tuning. Increase `target_accept` or reparameterize.
The acceptance probability does not match the target. It is 0.3781022302663828, but should be close to 0.8. Try to increase the number of tuning steps.
The acceptance probability does not match the target. It is 0.9912036985866993, but should be close to 0.8. Try to increase the number of tuning steps.
There were 4 divergences after tuning. Increase `target_accept` or reparameterize.
The acceptance probability does not match the target. It is 0.9714220559280048, but should be close to 0.8. Try to increase the number of tuning steps.
The chain reached the maximum tree depth. Increase max_treedepth, increase target_accept or reparameterize.
The rhat statistic is larger than 1.4 for some parameters. The sampler did not converge.
The estimated number of effective samples is smaller than 200 for some parameters.
mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk ess_tail r_hat
w[0] 0.641 2.708 -5.322 5.370 0.440 0.314 29.0 225.0 1.70
w[1] -1.517 3.947 -8.940 1.802 1.601 1.191 6.0 132.0 2.02
w[2] -1.360 3.630 -7.746 1.792 1.520 1.136 6.0 197.0 2.06
w[3] 4.632 3.526 1.255 10.937 1.501 1.123 5.0 13.0 2.02
w[4] 0.100 2.510 -5.863 2.587 0.789 0.575 11.0 111.0 1.70
w[5] 6.366 5.408 1.248 15.659 2.364 1.775 5.0 12.0 2.13
w[6] 4.975 4.082 1.231 12.371 1.671 1.245 6.0 15.0 2.08
w[7] 0.796 1.596 -2.893 2.935 0.430 0.311 13.0 139.0 1.55
w[8] -2.383 4.227 -8.455 1.791 2.020 1.536 5.0 57.0 2.18
w[9] 1.752 1.215 -0.814 4.463 0.026 0.125 88.0 140.0 1.92
b 13.342 3.916 8.121 22.311 0.923 0.684 17.0 148.0 1.40
w_param_1 3.466 3.795 0.000 9.498 1.736 1.311 5.0 14.0 2.26
w_param_0 1.784 0.353 1.198 2.470 0.078 0.056 16.0 43.0 1.66
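For context on how extreme that LogNormal prior is, here's a quick check of the prior quantiles implied by the constants in the model (plain NumPy, nothing model-specific):

```python
import numpy as np

mu, sigma = -22.047279357910156, 25.498497009277344
# central 95% interval of LogNormal(mu, sigma) is exp(mu +/- 1.96*sigma)
lo, hi = np.exp(mu - 1.96 * sigma), np.exp(mu + 1.96 * sigma)
print(lo, hi)  # ~5e-32 vs ~1e12: about 43 orders of magnitude

# the InverseGamma prior on w_param_0 is comparatively tight
alpha, beta = 17.76064682006836, 32.966636657714844
print(beta / (alpha - 1))  # prior mean ~1.97
```

So a priori the scale w_param_1 can range over dozens of orders of magnitude, which presumably makes the geometry hard for any sampler, yet Stan and NumPyro handle it below.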
Stan Program:
data {
  matrix[10, 10] x;
  real weight[10];  // declared but not used in the model block
  vector[10] y;
}
parameters {
  vector[10] w;
  real b;
  real<lower=0> w_param_0;
  real<lower=0> w_param_1;
}
model {
  w_param_1 ~ lognormal(-22.047279357910156, 25.498497009277344);
  w_param_0 ~ inv_gamma(17.76064682006836, 32.966636657714844);
  w ~ normal(w_param_0, w_param_1);
  b ~ normal(1.0, 10.0);
  y ~ normal(x * w + b, 1.0);
}
Stan Output:
Mean MCSE StdDev 5% 50% 95% N_Eff N_Eff/s R_hat
w[1] -4.4e-01 1.1e-01 3.6e+00 -6.3 -0.52 5.7 1.1e+03 1.0e+03 1.0e+00
w[2] -4.8e+00 9.1e-02 3.4e+00 -11 -4.6 0.57 1.4e+03 1.4e+03 1.0e+00
w[3] -4.5e+00 8.7e-02 2.9e+00 -9.0 -4.5 0.33 1.1e+03 1.1e+03 1.0e+00
w[4] 7.6e+00 6.3e-02 2.7e+00 3.4 7.5 12 1.8e+03 1.8e+03 1.0e+00
w[5] -1.2e+00 6.1e-02 2.8e+00 -5.9 -1.3 3.4 2.1e+03 2.1e+03 1.0e+00
w[6] 1.1e+01 1.0e-01 3.8e+00 5.4 11 18 1.3e+03 1.3e+03 1.0e+00
w[7] 8.2e+00 1.0e-01 3.4e+00 3.0 8.1 14 1.1e+03 1.1e+03 1.0e+00
w[8] -9.1e-02 4.4e-02 1.8e+00 -3.1 -0.11 3.0 1.8e+03 1.7e+03 1.0e+00
w[9] -6.5e+00 4.9e-02 1.8e+00 -9.4 -6.4 -3.5 1.4e+03 1.3e+03 1.0e+00
w[10] 1.9e+00 3.2e-02 1.6e+00 -0.89 1.9 4.5 2.7e+03 2.7e+03 1.0e+00
b 1.5e+01 1.2e-01 4.6e+00 7.1 15 22 1.4e+03 1.4e+03 1.0e+00
w_param_0 1.9e+00 7.3e-03 4.3e-01 1.3 1.8 2.7 3.6e+03 3.5e+03 1.0e+00
w_param_1 6.9e+00 4.6e-02 2.1e+00 4.2 6.5 11 2.1e+03 2.0e+03 1.0e+00
NumPyro Output:
mean std median 5.0% 95.0% n_eff r_hat
b 15.02 4.89 15.25 7.22 22.82 905.11 1.00
w[0] -0.33 3.74 -0.47 -6.30 5.59 898.13 1.00
w[1] -4.81 3.54 -4.51 -10.82 0.72 866.61 1.01
w[2] -4.35 2.98 -4.54 -9.31 0.27 812.20 1.01
w[3] 7.58 2.74 7.52 3.22 12.25 1696.11 1.00
w[4] -1.28 2.86 -1.30 -6.14 3.29 2122.09 1.00
w[5] 11.20 3.99 10.86 4.72 17.53 797.94 1.00
w[6] 8.35 3.50 8.07 3.03 14.42 762.38 1.01
w[7] -0.03 1.91 -0.10 -3.11 3.05 1245.41 1.00
w[8] -6.45 1.78 -6.45 -9.35 -3.52 1383.19 1.00
w[9] 1.83 1.70 1.85 -1.03 4.53 1988.71 1.00
w_param_0 1.91 0.45 1.85 1.22 2.60 2835.27 1.00
w_param_1 6.97 2.22 6.55 3.65 10.08 1358.71 1.00
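One idea I've been wondering about (not yet verified on this model): since LogNormal(mu, sigma) is just exp of a Normal(mu, sigma), sampling the hyperparameter on the log scale (a pm.Normal plus a pm.Deterministic wrapping theano.tensor.exp) should be an equivalent parameterization that might give NUTS friendlier geometry. The distributional equivalence itself is easy to sanity-check:

```python
import numpy as np

rng = np.random.default_rng(0)
mu, sigma = -2.0, 1.5  # small illustrative values, not the model's constants

# exp of a Normal draw is distributed as LogNormal(mu, sigma)
z = np.exp(rng.normal(mu, sigma, size=200_000))
x = rng.lognormal(mu, sigma, size=200_000)

# both should have log-scale mean ~ mu and log-scale std ~ sigma
print(np.log(z).mean(), np.log(x).mean())
print(np.log(z).std(), np.log(x).std())
```

Would this kind of reparameterization be the recommended fix here, or is something else going on in PyMC3's NUTS?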