Convergence issues with LogNormal

I've observed that, when using a LogNormal distribution with a large sigma, NUTS has trouble converging.
I tried different initializations, but that doesn't seem to change the results. If I change mu to -2, it converges to the correct result. Any ideas what is wrong here?

I get correct results with Stan and NumPyro using the same model.

import pymc3 as pm
import theano
import theano.tensor as tt
import arviz as az
import numpy as np


data = dict()
data['x'] = theano.shared(np.array([
    [0.3127, 0.7185, 0.9494, 0.5597, 0.7358, 0.1210, 0.6191, 0.5667, 0.4503, 0.5986],
    [0.0269, 0.7518, 0.9182, 0.6499, 0.4869, 0.3073, 0.8153, 0.6999, 0.0878, 0.1670],
    [0.7447, 0.5607, 0.0876, 0.7848, 0.5527, 0.6284, 0.2125, 0.6352, 0.8686, 0.7983],
    [0.1576, 0.0972, 0.6508, 0.7716, 0.6274, 0.1363, 0.7613, 0.2494, 0.3428, 0.9261],
    [0.7717, 0.4577, 0.0172, 0.8652, 0.8721, 0.1606, 0.9938, 0.0563, 0.5016, 0.5695],
    [0.1718, 0.2058, 0.9203, 0.6626, 0.5851, 0.1158, 0.4355, 0.0705, 0.7436, 0.2993],
    [0.3664, 0.5945, 0.0679, 0.5061, 0.9727, 0.2606, 0.8866, 0.7922, 0.8299, 0.4857],
    [0.1521, 0.6116, 0.8895, 0.8145, 0.8965, 0.3734, 0.5860, 0.2676, 0.9310, 0.6909],
    [0.6845, 0.6834, 0.3010, 0.4828, 0.8348, 0.3919, 0.4748, 0.6864, 0.1595, 0.2750],
    [0.1263, 0.8909, 0.6362, 0.1842, 0.4666, 0.9637, 0.3681, 0.4188, 0.3020, 0.7719],
]).astype("float64"))

data['y'] = theano.shared(np.array([
    14.3773, 22.0280, 21.5997, 24.4789, 25.8981,
    15.0056, 20.0997, 17.3859, 20.7086, 22.1846,
]).astype("float64"))


with pm.Model() as model:
    # Hyperpriors for the weights; note the very wide LogNormal on the scale
    w_param_1 = pm.Lognormal('w_param_1', mu=-22.047279357910156, sigma=25.498497009277344)
    w_param_0 = pm.InverseGamma('w_param_0', alpha=17.76064682006836, beta=32.966636657714844)
    # Weights and intercept
    w = pm.Normal('w', w_param_0, w_param_1, shape=(10,))
    b = pm.Normal('b', 1.0, 10.0)
    # Likelihood
    obs1 = pm.Normal('obs1', tt.dot(data['x'], w) + b, 1.0, observed=data['y'])

    step = pm.NUTS(max_treedepth=10, target_accept=0.8, step_scale=0.25)
    samples = pm.sample(draws=1000, chains=4, tune=1000, step=step, init="auto", jitter_max_retries=10)
    print(az.summary(samples))

Output:

NUTS: [b, w, w_param_0, w_param_1]
Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 94 seconds.
The acceptance probability does not match the target. It is 0.9492248356913224, but should be close to 0.8. Try to increase the number of tuning steps.
There were 559 divergences after tuning. Increase `target_accept` or reparameterize.
The acceptance probability does not match the target. It is 0.3781022302663828, but should be close to 0.8. Try to increase the number of tuning steps.
The acceptance probability does not match the target. It is 0.9912036985866993, but should be close to 0.8. Try to increase the number of tuning steps.
There were 4 divergences after tuning. Increase `target_accept` or reparameterize.
The acceptance probability does not match the target. It is 0.9714220559280048, but should be close to 0.8. Try to increase the number of tuning steps.
The chain reached the maximum tree depth. Increase max_treedepth, increase target_accept or reparameterize.
The rhat statistic is larger than 1.4 for some parameters. The sampler did not converge.
The estimated number of effective samples is smaller than 200 for some parameters.
             mean     sd  hdi_3%  hdi_97%  mcse_mean  mcse_sd  ess_bulk  ess_tail  r_hat
w[0]        0.641  2.708  -5.322    5.370      0.440    0.314      29.0     225.0   1.70
w[1]       -1.517  3.947  -8.940    1.802      1.601    1.191       6.0     132.0   2.02
w[2]       -1.360  3.630  -7.746    1.792      1.520    1.136       6.0     197.0   2.06
w[3]        4.632  3.526   1.255   10.937      1.501    1.123       5.0      13.0   2.02
w[4]        0.100  2.510  -5.863    2.587      0.789    0.575      11.0     111.0   1.70
w[5]        6.366  5.408   1.248   15.659      2.364    1.775       5.0      12.0   2.13
w[6]        4.975  4.082   1.231   12.371      1.671    1.245       6.0      15.0   2.08
w[7]        0.796  1.596  -2.893    2.935      0.430    0.311      13.0     139.0   1.55
w[8]       -2.383  4.227  -8.455    1.791      2.020    1.536       5.0      57.0   2.18
w[9]        1.752  1.215  -0.814    4.463      0.026    0.125      88.0     140.0   1.92
b          13.342  3.916   8.121   22.311      0.923    0.684      17.0     148.0   1.40
w_param_1   3.466  3.795   0.000    9.498      1.736    1.311       5.0      14.0   2.26
w_param_0   1.784  0.353   1.198    2.470      0.078    0.056      16.0      43.0   1.66      
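
To see where the divergences occur, the trace can be inspected with ArviZ. A minimal sketch, assuming the samples and model objects from the run above:

idata = az.from_pymc3(samples, model=model)
# Total number of divergent transitions across all chains
print(int(idata.sample_stats["diverging"].sum()))
# Divergences clustering at small w_param_1 would point to a funnel geometry
az.plot_pair(idata, var_names=["w_param_1", "w_param_0"], divergences=True)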

Stan Program:

data {
  matrix[10, 10] x;
  real weight[10];
  vector[10] y;
}
parameters {
  vector[10] w;
  real b;
  real<lower=0> w_param_0;
  real<lower=0> w_param_1;
}
model {
  w_param_1 ~ lognormal(-22.047279357910156, 25.498497009277344);
  w_param_0 ~ inv_gamma(17.76064682006836, 32.966636657714844);
  w ~ normal(w_param_0, w_param_1);
  b ~ normal(1.0, 10.0);
  y ~ normal(x * w + b, 1.0);
}

Stan Output:

                    Mean     MCSE   StdDev     5%    50%    95%    N_Eff  N_Eff/s    R_hat
w[1]            -4.4e-01  1.1e-01  3.6e+00   -6.3  -0.52    5.7  1.1e+03  1.0e+03  1.0e+00
w[2]            -4.8e+00  9.1e-02  3.4e+00    -11   -4.6   0.57  1.4e+03  1.4e+03  1.0e+00
w[3]            -4.5e+00  8.7e-02  2.9e+00   -9.0   -4.5   0.33  1.1e+03  1.1e+03  1.0e+00
w[4]             7.6e+00  6.3e-02  2.7e+00    3.4    7.5     12  1.8e+03  1.8e+03  1.0e+00
w[5]            -1.2e+00  6.1e-02  2.8e+00   -5.9   -1.3    3.4  2.1e+03  2.1e+03  1.0e+00
w[6]             1.1e+01  1.0e-01  3.8e+00    5.4     11     18  1.3e+03  1.3e+03  1.0e+00
w[7]             8.2e+00  1.0e-01  3.4e+00    3.0    8.1     14  1.1e+03  1.1e+03  1.0e+00
w[8]            -9.1e-02  4.4e-02  1.8e+00   -3.1  -0.11    3.0  1.8e+03  1.7e+03  1.0e+00
w[9]            -6.5e+00  4.9e-02  1.8e+00   -9.4   -6.4   -3.5  1.4e+03  1.3e+03  1.0e+00
w[10]            1.9e+00  3.2e-02  1.6e+00  -0.89    1.9    4.5  2.7e+03  2.7e+03  1.0e+00
b                1.5e+01  1.2e-01  4.6e+00    7.1     15     22  1.4e+03  1.4e+03  1.0e+00
w_param_0        1.9e+00  7.3e-03  4.3e-01    1.3    1.8    2.7  3.6e+03  3.5e+03  1.0e+00
w_param_1        6.9e+00  4.6e-02  2.1e+00    4.2    6.5     11  2.1e+03  2.0e+03  1.0e+00

NumPyro Output:

                 mean       std    median      5.0%     95.0%     n_eff     r_hat
          b     15.02      4.89     15.25      7.22     22.82    905.11      1.00
       w[0]     -0.33      3.74     -0.47     -6.30      5.59    898.13      1.00
       w[1]     -4.81      3.54     -4.51    -10.82      0.72    866.61      1.01
       w[2]     -4.35      2.98     -4.54     -9.31      0.27    812.20      1.01
       w[3]      7.58      2.74      7.52      3.22     12.25   1696.11      1.00
       w[4]     -1.28      2.86     -1.30     -6.14      3.29   2122.09      1.00
       w[5]     11.20      3.99     10.86      4.72     17.53    797.94      1.00
       w[6]      8.35      3.50      8.07      3.03     14.42    762.38      1.01
       w[7]     -0.03      1.91     -0.10     -3.11      3.05   1245.41      1.00
       w[8]     -6.45      1.78     -6.45     -9.35     -3.52   1383.19      1.00
       w[9]      1.83      1.70      1.85     -1.03      4.53   1988.71      1.00
  w_param_0      1.91      0.45      1.85      1.22      2.60   2835.27      1.00
  w_param_1      6.97      2.22      6.55      3.65     10.08   1358.71      1.00
                                                                                     

How many draws and tuning iterations are you using in the other samplers? Looking quickly, the smallest posterior mean among your “w”s is ~ -6.5, which is much higher than -22.

My first guess: I suspect that 1000 draws and 1000 tuning iterations need to be increased, since the prior seems far off from the posterior (as the warnings suggest). This aligns with the fact that using -2 as your prior mu yields better results, since it is closer to the posterior means of your “w” parameters. Let me know if this helps (or not so much)!
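
For example, something like the following might already help. This is an untested sketch; tune=2000 and target_accept=0.9 are illustrative values, not taken from your runs:

with model:
    # More tuning iterations and a higher acceptance target force smaller
    # leapfrog steps, which usually reduces (or removes) divergences.
    samples = pm.sample(draws=1000, tune=2000, chains=4, target_accept=0.9)
    print(az.summary(samples))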


I use 1000 tuning / 1000 sampling iterations and 4 chains in all three systems.

I tried 2000 tuning samples with PyMC3. The results get better, but I still see occasional divergences (see the reparameterization sketch after the table).

             mean     sd  hdi_3%  hdi_97%  mcse_mean  mcse_sd  ess_bulk  ess_tail  r_hat
w[0]       -0.085  3.621  -7.281    6.183      0.122    0.105     993.0     975.0   1.01
w[1]       -4.917  3.438 -11.145    1.601      0.113    0.093    1058.0    1062.0   1.00
w[2]       -4.177  2.849  -9.466    1.148      0.099    0.070     957.0     851.0   1.01
w[3]        7.440  2.749   2.189   12.611      0.066    0.048    1733.0    2008.0   1.00
w[4]       -1.425  2.922  -7.102    3.886      0.063    0.048    2130.0    2612.0   1.00
w[5]       11.256  3.858   4.259   18.376      0.128    0.102    1058.0     998.0   1.00
w[6]        8.562  3.412   1.900   14.690      0.118    0.094     969.0     857.0   1.01
w[7]        0.021  1.867  -3.380    3.591      0.050    0.039    1481.0    1218.0   1.00
w[8]       -6.342  1.729  -9.586   -3.130      0.047    0.033    1395.0    1375.0   1.01
w[9]        1.902  1.682  -1.062    5.163      0.038    0.027    1966.0    2294.0   1.00
b          14.815  4.701   6.633   23.998      0.144    0.102    1211.0    1071.0   1.00
w_param_1   6.894  2.078   3.730   10.889      0.053    0.041    1862.0    1738.0   1.00
w_param_0   1.919  0.464   1.163    2.807      0.008    0.006    3397.0    2546.0   1.00
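
If occasional divergences persist even with more tuning, the "reparameterize" suggestion in the warnings can be followed with a non-centered parameterization of w. A minimal sketch, assuming the same data dict and imports as above; model_nc and w_raw are new helper names, not part of the original model, and tune=2000 / target_accept=0.9 are again illustrative:

with pm.Model() as model_nc:
    w_param_1 = pm.Lognormal('w_param_1', mu=-22.047279357910156, sigma=25.498497009277344)
    w_param_0 = pm.InverseGamma('w_param_0', alpha=17.76064682006836, beta=32.966636657714844)
    # Non-centered: sample a standard normal and scale/shift it, so the
    # geometry of w no longer depends directly on w_param_0 and w_param_1.
    w_raw = pm.Normal('w_raw', 0.0, 1.0, shape=(10,))
    w = pm.Deterministic('w', w_param_0 + w_param_1 * w_raw)
    b = pm.Normal('b', 1.0, 10.0)
    obs1 = pm.Normal('obs1', tt.dot(data['x'], w) + b, 1.0, observed=data['y'])
    samples = pm.sample(draws=1000, tune=2000, chains=4, target_accept=0.9)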