PyMC3 produces different results than Stan/NumPyro

PyMC3 produces weird results for this multiple linear regression model. I tried the same model with Stan and NumPyro, and the results of both are quite reasonable. The distribution parameters seem to be the same across the three systems. I tried all the different initialization strategies, but that makes no difference in the results.

Do you guys know why this happens? Is this a bug? Or am I doing something wrong in the PyMC3 program, e.g. the matrix multiplication? (A quick sanity check for that is sketched after the PyMC3 output below.)

import pymc3 as pm
import theano
import arviz as az
import numpy as np

data = dict()
data['x'] =theano.shared(np.array([[0.3127,0.7185,0.9494,0.5597,0.7358,0.1210,0.6191,0.5667,0.4503,0.5986],[0.0269,0.7518,0.9182,0.6499,0.4869,0.3073,0.8153,0.6999,0.0878,0.1670],[0.7447,0.5607,0.0876,0.7848,0.5527,0.6284,0.2125,0.6352,0.8686,0.7983],[0.1576,0.0972,0.6508,0.7716,0.6274,0.1363,0.7613,0.2494,0.3428,0.9261],[0.7717,0.4577,0.0172,0.8652,0.8721,0.1606,0.9938,0.0563,0.5016,0.5695],[0.1718,0.2058,0.9203,0.6626,0.5851,0.1158,0.4355,0.0705,0.7436,0.2993],[0.3664,0.5945,0.0679,0.5061,0.9727,0.2606,0.8866,0.7922,0.8299,0.4857],[0.1521,0.6116,0.8895,0.8145,0.8965,0.3734,0.5860,0.2676,0.9310,0.6909],[0.6845,0.6834,0.3010,0.4828,0.8348,0.3919,0.4748,0.6864,0.1595,0.2750],[0.1263,0.8909,0.6362,0.1842,0.4666,0.9637,0.3681,0.4188,0.3020,0.7719]]).astype("float64"))
data['y'] =theano.shared(np.array([14.3773,22.0280,21.5997,24.4789,25.8981,15.0056,20.0997,17.3859,20.7086,22.1846]).astype("float64"))


with pm.Model() as model:
    w = pm.StudentT('w', 78.04098510742188, -17.60015869140625, 68.0557632446289, shape=(10,))
    b = pm.Normal('b', 1.0, 10.0)
    obs1 = pm.Normal('obs1', theano.tensor.dot(data['x'], w) - b, 1.0, observed=data['y'])

    step = pm.step_methods.hmc.nuts.NUTS(max_treedepth=10, target_accept=0.8, step_scale=0.25)
    samples = pm.sample(draws=1000, chains=4, tune=1000, step=step, init="auto", jitter_max_retries=10)
    print(az.summary(samples))

Output:

         mean     sd   hdi_3%  hdi_97%  mcse_mean  mcse_sd  ess_bulk  ess_tail  r_hat
w[0]  -17.453  0.123  -17.691  -17.232      0.002    0.001    5836.0    2917.0    1.0
w[1]  -17.443  0.125  -17.670  -17.202      0.002    0.001    5165.0    3106.0    1.0
w[2]  -17.860  0.129  -18.111  -17.630      0.002    0.001    5638.0    2882.0    1.0
w[3]  -17.533  0.123  -17.764  -17.302      0.002    0.001    5194.0    2780.0    1.0
w[4]  -17.471  0.125  -17.693  -17.234      0.002    0.001    5086.0    3520.0    1.0
w[5]  -17.474  0.121  -17.704  -17.249      0.002    0.001    4492.0    2993.0    1.0
w[6]  -17.548  0.121  -17.758  -17.303      0.002    0.001    5446.0    2905.0    1.0
w[7]  -17.468  0.123  -17.696  -17.238      0.002    0.001    5593.0    3011.0    1.0
w[8]  -17.415  0.129  -17.665  -17.176      0.002    0.001    5372.0    3091.0    1.0
w[9]  -17.421  0.127  -17.652  -17.182      0.002    0.001    4922.0    3129.0    1.0
b    -112.547  0.388 -113.228 -111.780      0.006    0.004    4087.0    3309.0    1.0 
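To rule out the matrix multiplication itself, a quick check along these lines (a sketch with an arbitrary test vector, reusing data and the imports from the program above; not part of the original run) should show whether theano.tensor.dot agrees with NumPy:

# Sketch: compare the Theano dot product against plain NumPy on the same data.
x_np = data['x'].get_value()
w_test = np.arange(10, dtype="float64")            # arbitrary test vector
theano_result = theano.tensor.dot(data['x'], w_test).eval()
print(np.allclose(theano_result, x_np @ w_test))   # expected: True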

Stan Program:

data {
  matrix[10, 10] x;
  real weight[10];
  vector[10] y;
}
parameters {
  vector[10] w;
  real b;
}
model {
  w ~ student_t(78.04098510742188, -17.60015869140625, 68.0557632446289);
  b ~ normal(1.0, 10.0);
  y ~ normal(x * w - b, 1.0);
}

Stan output:

                 Mean     MCSE   StdDev     5%    50%    95%    N_Eff  N_Eff/s    R_hat
w[1]               10  3.1e-01  8.4e+00   -3.5     11     24  7.5e+02  4.5e+02  1.0e+00
w[2]              -16  2.5e-01  7.2e+00    -28    -16   -4.2  8.1e+02  4.9e+02  1.0e+00
w[3]              4.9  2.5e-01  6.8e+00   -6.1    5.0     16  7.6e+02  4.6e+02  1.0e+00
w[4]              4.9  1.4e-01  4.2e+00   -2.0    4.9     12  9.0e+02  5.5e+02  1.0e+00
w[5]             -1.6  9.9e-02  3.5e+00   -7.2   -1.6    4.4  1.3e+03  7.8e+02  1.0e+00
w[6]               24  2.7e-01  7.8e+00     12     24     37  8.0e+02  4.9e+02  1.0e+00
w[7]               19  2.9e-01  7.7e+00    6.5     20     32  7.2e+02  4.4e+02  1.0e+00
w[8]              4.2  1.1e-01  3.3e+00   -1.0    4.2    9.5  9.0e+02  5.5e+02  1.0e+00
w[9]             -3.9  1.0e-01  3.0e+00   -8.7   -3.8    1.2  8.0e+02  4.8e+02  1.0e+00
w[10]           -0.48  5.7e-02  2.0e+00   -3.7  -0.47    2.8  1.2e+03  7.3e+02  1.0e+00
b                -1.2  3.3e-01  9.5e+00    -16   -1.2     15  8.4e+02  5.1e+02  1.0e+00

NumPyro Program:


import jax.random
import jax.numpy as jnp
import numpyro
import numpy as np
import numpyro.distributions as dist

from numpyro.infer import MCMC, NUTS, HMC
numpyro.set_platform("cpu")
numpyro.set_host_device_count(4)
numpyro.enable_x64()


data = dict()
data['x'] =np.array([[0.3127,0.7185,0.9494,0.5597,0.7358,0.1210,0.6191,0.5667,0.4503,0.5986],[0.0269,0.7518,0.9182,0.6499,0.4869,0.3073,0.8153,0.6999,0.0878,0.1670],[0.7447,0.5607,0.0876,0.7848,0.5527,0.6284,0.2125,0.6352,0.8686,0.7983],[0.1576,0.0972,0.6508,0.7716,0.6274,0.1363,0.7613,0.2494,0.3428,0.9261],[0.7717,0.4577,0.0172,0.8652,0.8721,0.1606,0.9938,0.0563,0.5016,0.5695],[0.1718,0.2058,0.9203,0.6626,0.5851,0.1158,0.4355,0.0705,0.7436,0.2993],[0.3664,0.5945,0.0679,0.5061,0.9727,0.2606,0.8866,0.7922,0.8299,0.4857],[0.1521,0.6116,0.8895,0.8145,0.8965,0.3734,0.5860,0.2676,0.9310,0.6909],[0.6845,0.6834,0.3010,0.4828,0.8348,0.3919,0.4748,0.6864,0.1595,0.2750],[0.1263,0.8909,0.6362,0.1842,0.4666,0.9637,0.3681,0.4188,0.3020,0.7719]])
data['weight'] =np.array([3.9238,3.1421,6.5905,2.8718,1.7819,1.4975,0.1834,7.6811,5.0964,2.7595])
data['y'] =np.array([14.3773,22.0280,21.5997,24.4789,25.8981,15.0056,20.0997,17.3859,20.7086,22.1846])

def model():
    w = numpyro.sample("w", dist.StudentT(df=78.04098510742188 * np.ones([10]),
                                          loc=-17.60015869140625 * np.ones([10]),
                                          scale=68.0557632446289 * np.ones([10])))
    b = numpyro.sample("b", dist.Normal(1.0, 10.0))
    with numpyro.plate("size", np.size(data['y'])):
        numpyro.sample("obs27", dist.Normal(jnp.matmul(data['x'], w) - b, 1.0), obs=data['y'])


mcmc = MCMC(numpyro.infer.NUTS(model), num_samples=1000, num_warmup=1000, num_chains=4)
mcmc.run(jax.random.PRNGKey(2080339834))
params = mcmc.get_samples(group_by_chain=True)
mcmc.print_summary()

NumPyro output:

                mean       std    median      5.0%     95.0%     n_eff     r_hat
         b     -1.69      9.71     -1.69    -17.84     13.39    837.72      1.00
      w[0]      9.96      8.76      9.98     -3.39     24.46    773.45      1.00
      w[1]    -15.62      7.25    -15.68    -27.79     -4.56    816.07      1.00
      w[2]      4.53      6.98      4.56     -6.41     15.65    773.57      1.00
      w[3]      5.05      4.34      5.07     -1.90     12.12    894.61      1.00
      w[4]     -1.49      3.60     -1.48     -7.36      4.41   1413.75      1.00
      w[5]     23.85      7.93     23.81     11.97     37.11    831.59      1.00
      w[6]     18.87      8.01     18.98      6.06     31.51    777.94      1.00
      w[7]      4.07      3.26      4.10     -1.31      9.27    881.17      1.00
      w[8]     -4.01      3.05     -4.00     -8.93      0.90    881.75      1.00
      w[9]     -0.50      1.98     -0.54     -3.83      2.62   1546.71      1.00

Just to rule out something obvious, do you get the same weird results if you just run with vanilla pm.sample, instead of trying to initialize NUTS manually?
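For instance, a minimal sketch reusing the model from above:

with model:
    # Plain pm.sample: let PyMC3 pick the step method and initialization itself.
    samples = pm.sample(draws=1000, tune=1000, chains=4, return_inferencedata=True)
print(az.summary(samples))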

Yes, same results.

Does the sampler give any indication that it had problems sampling? The results look really wrong.
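For example, a sketch of checking basic diagnostics (it assumes the trace is first converted to InferenceData inside the model context):

with model:
    idata = az.from_pymc3(samples)
# Number of divergent transitions and per-parameter r_hat.
print("divergences:", int(idata.sample_stats["diverging"].sum()))
print(az.rhat(idata))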

Can you try with
pm.StudentT('w',78.04098510742188,-17.60015869140625,sigma=68.0557632446289, shape=(10,))


Specifically, I would explicitly name each argument (and generally recommend doing so):

w=pm.StudentT('w',
              nu=78.04098510742188,
              mu=-17.60015869140625,
              sigma=68.0557632446289,
              shape=(10,))

That gets me results that are pretty close to those you are getting from Stan.


Yes, it looks like it was assuming the third parameter is lam instead of sigma. Changing that improves the result, but it is still a bit different from Stan's. I am seeing some convergence issues. I tried up to 3000 warmup iterations and some different initializations. Still the same thing. Any suggestions?
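For reference, a sketch of the two parameterizations (hypothetical variable names, assuming the intended scale is sigma = 68.0557632446289). In PyMC3 the third positional argument of StudentT is lam, the precision, and lam = 1/sigma**2, so the positional call defined a much tighter prior than intended:

with pm.Model():
    nu, mu, sigma = 78.04098510742188, -17.60015869140625, 68.0557632446289
    # Intended prior: scale ~68 (equivalently lam = 1/sigma**2 ~= 2.2e-4).
    w_intended = pm.StudentT('w_intended', nu=nu, mu=mu, sigma=sigma, shape=(10,))
    # What the positional call actually defined: lam (precision) ~68, i.e. an
    # effective sigma of 1/sqrt(68.06) ~= 0.12 -- a very tight prior around mu,
    # consistent with the w estimates clustering near -17.5 in the first output.
    w_actual = pm.StudentT('w_actual', nu=nu, mu=mu, lam=sigma, shape=(10,))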

I checked that if I start with start={'b': -1.2}, then it converges.

WARNING (theano.tensor.blas): Using NumPy C-API based implementation for BLAS functions.
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [b, w]
Sampling 4 chains for 1_000 tune and 2_000 draw iterations (4_000 + 8_000 draws total) took 30 seconds.
There were 977 divergences after tuning. Increase `target_accept` or reparameterize.
The acceptance probability does not match the target. It is 0.48232653725158614, but should be close to 0.8. Try to increase the number of tuning steps.
The acceptance probability does not match the target. It is 0.8824357326401848, but should be close to 0.8. Try to increase the number of tuning steps.
The rhat statistic is larger than 1.05 for some parameters. This indicates slight problems during sampling.
The estimated number of effective samples is smaller than 200 for some parameters.

        mean     sd  hdi_3%  hdi_97%  mcse_mean  mcse_sd  ess_bulk  ess_tail  r_hat
w[0]  11.310  7.584  -3.291   26.709      0.299    0.212     599.0    1098.0   1.17
w[1] -17.178  6.528 -29.643   -4.340      0.510    0.362     136.0    1135.0   1.16
w[2]   5.728  6.095  -6.222   17.727      0.247    0.175     569.0    1042.0   1.18
w[3]   4.186  4.133  -2.835   12.340      0.541    0.384      62.0    1225.0   1.04
w[4]  -1.028  3.401  -7.670    4.739      0.669    0.479      26.0     896.0   1.10
w[5]  25.289  6.999  11.551   38.553      0.373    0.264     357.0     879.0   1.14
w[6]  20.020  6.852   6.446   33.544      0.252    0.178     701.0     925.0   1.18
w[7]   4.631  2.951  -1.629    9.870      0.157    0.111     317.0    1258.0   1.08
w[8]  -3.700  2.586  -8.616    1.778      0.082    0.074     888.0    1305.0   1.17
w[9]  -0.751  1.833  -4.126    3.072      0.120    0.085     289.0    2021.0   1.11
b     -0.181  8.509 -17.118   16.054      0.329    0.857     631.0    1140.0   1.14   

I am getting no divergences and seemingly good convergence (note the r_hat values). I just cut and pasted your code. I am running pymc v3.11.2 and theano-pymc v1.1.2 (a quick way to check versions is sketched after the output below).

Multiprocess sampling (4 chains in 2 jobs)
NUTS: [b, w]
Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 40 seconds.
The acceptance probability does not match the target. It is 0.8817572415352573, but should be close to 0.8. Try to increase the number of tuning steps.
The acceptance probability does not match the target. It is 0.8824272675568593, but should be close to 0.8. Try to increase the number of tuning steps.
The number of effective samples is smaller than 25% for some parameters.
        mean     sd  hdi_3%  hdi_97%  mcse_mean  mcse_sd  ess_bulk  ess_tail  r_hat
w[0]  10.514  8.580  -5.108   27.238      0.358    0.256     578.0     754.0    1.0
w[1] -16.138  7.215 -29.167   -1.851      0.299    0.212     585.0     735.0    1.0
w[2]   4.954  6.910  -7.890   18.009      0.289    0.205     572.0     723.0    1.0
w[3]   4.810  4.263  -3.699   12.306      0.158    0.112     733.0    1192.0    1.0
w[4]  -1.607  3.621  -8.023    5.442      0.091    0.065    1574.0    2033.0    1.0
w[5]  24.351  7.859   9.767   39.080      0.321    0.227     603.0     808.0    1.0
w[6]  19.428  7.858   4.582   33.896      0.328    0.233     577.0     729.0    1.0
w[7]   4.293  3.254  -1.798   10.347      0.126    0.089     667.0     956.0    1.0
w[8]  -3.805  3.015  -9.293    2.117      0.117    0.083     670.0     937.0    1.0
w[9]  -0.565  1.997  -4.291    3.273      0.061    0.043    1104.0    1689.0    1.0
b     -1.092  9.642 -19.659   16.161      0.385    0.273     628.0     848.0    1.0
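In case the environments differ, a quick version check (just a sketch) on both machines could look like:

import pymc3, theano, arviz
print(pymc3.__version__, theano.__version__, arviz.__version__)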

Interestingly, it looks like the non-convergence happens if I set this seed at the beginning:
np.random.seed(2008007173)

If I remove it, then it converges. Probably in some rare executions it does not converge.
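To probe how often that happens, something along these lines might help (a sketch; it just passes an explicit random_seed to pm.sample and compares the worst-case r_hat across a few runs):

# Sketch: rerun the same model with a few explicit seeds and compare worst r_hat.
for seed in [1, 42, 2008007173]:
    with model:
        idata = pm.sample(draws=1000, tune=1000, chains=4,
                          random_seed=seed, return_inferencedata=True,
                          progressbar=False)
    print(seed, float(az.rhat(idata).to_array().max()))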