PyMC3 produces strange results for this multiple linear regression model. I tried the same model with Stan and NumPyro, and both give quite reasonable results. The distribution parameters seem to be the same across all three systems. I tried all the different initialization strategies, but they make no difference to the results.
Does anyone know why this happens? Is this a bug, or am I doing something wrong in the PyMC3 program (e.g., the matrix multiplication)?
import pymc3 as pm
import theano
import arviz as az
import numpy as np
# Observed data for the PyMC3 model, wrapped in Theano shared variables so they
# can be used inside the model's Theano expressions (see theano.tensor.dot below).
# 'x' holds one row per observation and one column per predictor; 'y' holds one
# response per row of 'x'.
data = {
    'x': theano.shared(
        np.array(
            [
                [0.3127, 0.7185, 0.9494, 0.5597, 0.7358, 0.1210, 0.6191, 0.5667, 0.4503, 0.5986],
                [0.0269, 0.7518, 0.9182, 0.6499, 0.4869, 0.3073, 0.8153, 0.6999, 0.0878, 0.1670],
                [0.7447, 0.5607, 0.0876, 0.7848, 0.5527, 0.6284, 0.2125, 0.6352, 0.8686, 0.7983],
                [0.1576, 0.0972, 0.6508, 0.7716, 0.6274, 0.1363, 0.7613, 0.2494, 0.3428, 0.9261],
                [0.7717, 0.4577, 0.0172, 0.8652, 0.8721, 0.1606, 0.9938, 0.0563, 0.5016, 0.5695],
                [0.1718, 0.2058, 0.9203, 0.6626, 0.5851, 0.1158, 0.4355, 0.0705, 0.7436, 0.2993],
                [0.3664, 0.5945, 0.0679, 0.5061, 0.9727, 0.2606, 0.8866, 0.7922, 0.8299, 0.4857],
                [0.1521, 0.6116, 0.8895, 0.8145, 0.8965, 0.3734, 0.5860, 0.2676, 0.9310, 0.6909],
                [0.6845, 0.6834, 0.3010, 0.4828, 0.8348, 0.3919, 0.4748, 0.6864, 0.1595, 0.2750],
                [0.1263, 0.8909, 0.6362, 0.1842, 0.4666, 0.9637, 0.3681, 0.4188, 0.3020, 0.7719],
            ],
            dtype="float64",
        )
    ),
    'y': theano.shared(
        np.array(
            [14.3773, 22.0280, 21.5997, 24.4789, 25.8981,
             15.0056, 20.0997, 17.3859, 20.7086, 22.1846],
            dtype="float64",
        )
    ),
}
with pm.Model() as model:
    # Student-t prior on the 10 regression weights.
    # BUG FIX: PyMC3's signature is StudentT(nu, mu=0, lam=None, sigma=None), so
    # the third *positional* argument is `lam`, the PRECISION, not the scale.
    # Passing 68.0557... positionally meant sigma = 1/sqrt(68.06) ~= 0.12, which
    # pinned every weight near -17.6 (hence the sd of ~0.12 in the PyMC3 summary).
    # Passing `sigma` by keyword matches Stan's student_t(nu, mu, sigma) and
    # NumPyro's StudentT(df, loc, scale).
    w = pm.StudentT(
        'w',
        nu=78.04098510742188,
        mu=-17.60015869140625,
        sigma=68.0557632446289,
        shape=(10,),
    )
    # Normal prior on the offset b; keyword arguments make the (mu, sigma)
    # parameterization explicit instead of relying on positional order.
    b = pm.Normal('b', mu=1.0, sigma=10.0)
    # Likelihood: y ~ Normal(x @ w - b, 1), the same linear predictor as the
    # Stan (x*w-b) and NumPyro (matmul(x, w) - b) versions.
    obs1 = pm.Normal(
        'obs1',
        mu=theano.tensor.dot(data['x'], w) - b,
        sigma=1.0,
        observed=data['y'],
    )
with model:
    # Explicitly configured NUTS sampler: tree depth capped at 10, target
    # acceptance rate 0.8, initial step-size scale 0.25.
    step = pm.NUTS(
        max_treedepth=10,
        target_accept=0.8,
        step_scale=0.25,
    )
    # 4 chains x (1000 tuning + 1000 retained) draws using the sampler above.
    # NOTE(review): per the PyMC3 `sample` docs, `init` (and the jitter retries)
    # only configure an auto-assigned NUTS sampler; because `step` is passed
    # explicitly, these arguments most likely have no effect here. That would
    # explain why switching init strategies changed nothing. Confirm for the
    # installed PyMC3 version.
    samples = pm.sample(
        draws=1000,
        tune=1000,
        chains=4,
        step=step,
        init="auto",
        jitter_max_retries=10,
    )
# Tabular posterior summary (mean, sd, HDI, MCSE, ESS, r_hat) via ArviZ.
print(az.summary(samples))
PyMC3 output:
mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk ess_tail r_hat
w[0] -17.453 0.123 -17.691 -17.232 0.002 0.001 5836.0 2917.0 1.0
w[1] -17.443 0.125 -17.670 -17.202 0.002 0.001 5165.0 3106.0 1.0
w[2] -17.860 0.129 -18.111 -17.630 0.002 0.001 5638.0 2882.0 1.0
w[3] -17.533 0.123 -17.764 -17.302 0.002 0.001 5194.0 2780.0 1.0
w[4] -17.471 0.125 -17.693 -17.234 0.002 0.001 5086.0 3520.0 1.0
w[5] -17.474 0.121 -17.704 -17.249 0.002 0.001 4492.0 2993.0 1.0
w[6] -17.548 0.121 -17.758 -17.303 0.002 0.001 5446.0 2905.0 1.0
w[7] -17.468 0.123 -17.696 -17.238 0.002 0.001 5593.0 3011.0 1.0
w[8] -17.415 0.129 -17.665 -17.176 0.002 0.001 5372.0 3091.0 1.0
w[9] -17.421 0.127 -17.652 -17.182 0.002 0.001 4922.0 3129.0 1.0
b -112.547 0.388 -113.228 -111.780 0.006 0.004 4087.0 3309.0 1.0
Stan Program:
// Bayesian multiple linear regression: y ~ normal(x * w - b, 1).
data{
// Design matrix: one row per observation (matched against y), one column per weight.
matrix[10,10] x;
// Per-observation weights; declared as data but never used in the model block.
// NOTE(review): this is the pre-2.33 array syntax; newer Stan releases require
// `array[10] real weight;` -- confirm against the Stan version in use.
real weight[10];
// Observed responses, one per row of x.
vector[10] y;
}
parameters{
// Regression coefficients, one per column of x.
vector[10] w;
// Offset subtracted from the linear predictor.
real b;
}
model{
// Student-t prior; Stan's student_t(nu, mu, sigma) takes the SCALE sigma as its
// third argument (not a precision).
w~student_t(78.04098510742188,-17.60015869140625,68.0557632446289);
// Normal(mean, standard deviation) prior on the offset.
b~normal(1.0,10.0);
// Likelihood with unit observation noise; x*w is a matrix-vector product.
y~normal(x*w-b,1.0);
}
Stan output:
Mean MCSE StdDev 5% 50% 95% N_Eff N_Eff/s R_hat
w[1] 10 3.1e-01 8.4e+00 -3.5 11 24 7.5e+02 4.5e+02 1.0e+00
w[2] -16 2.5e-01 7.2e+00 -28 -16 -4.2 8.1e+02 4.9e+02 1.0e+00
w[3] 4.9 2.5e-01 6.8e+00 -6.1 5.0 16 7.6e+02 4.6e+02 1.0e+00
w[4] 4.9 1.4e-01 4.2e+00 -2.0 4.9 12 9.0e+02 5.5e+02 1.0e+00
w[5] -1.6 9.9e-02 3.5e+00 -7.2 -1.6 4.4 1.3e+03 7.8e+02 1.0e+00
w[6] 24 2.7e-01 7.8e+00 12 24 37 8.0e+02 4.9e+02 1.0e+00
w[7] 19 2.9e-01 7.7e+00 6.5 20 32 7.2e+02 4.4e+02 1.0e+00
w[8] 4.2 1.1e-01 3.3e+00 -1.0 4.2 9.5 9.0e+02 5.5e+02 1.0e+00
w[9] -3.9 1.0e-01 3.0e+00 -8.7 -3.8 1.2 8.0e+02 4.8e+02 1.0e+00
w[10] -0.48 5.7e-02 2.0e+00 -3.7 -0.47 2.8 1.2e+03 7.3e+02 1.0e+00
b -1.2 3.3e-01 9.5e+00 -16 -1.2 15 8.4e+02 5.1e+02 1.0e+00
NumPyro Program:
import jax.random
import jax.numpy as jnp
import numpyro
import numpy as np
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, HMC
# Run JAX computations on the CPU.
numpyro.set_platform("cpu")
# Expose 4 host (CPU) devices, presumably so the 4 MCMC chains requested below
# (num_chains=4) can run in parallel.
# NOTE(review): per the NumPyro docs this must be called before JAX initializes
# its backend, so keep it at the top of the script.
numpyro.set_host_device_count(4)
# Use 64-bit floats, consistent with the float64 data used in the PyMC3 script.
numpyro.enable_x64()
# Observed data for the NumPyro model:
#   'x'      -- design matrix, one row per observation, one column per weight
#   'weight' -- per-observation weights (not referenced by model())
#   'y'      -- observed responses, one per row of 'x'
data = {
    'x': np.array(
        [
            [0.3127, 0.7185, 0.9494, 0.5597, 0.7358, 0.1210, 0.6191, 0.5667, 0.4503, 0.5986],
            [0.0269, 0.7518, 0.9182, 0.6499, 0.4869, 0.3073, 0.8153, 0.6999, 0.0878, 0.1670],
            [0.7447, 0.5607, 0.0876, 0.7848, 0.5527, 0.6284, 0.2125, 0.6352, 0.8686, 0.7983],
            [0.1576, 0.0972, 0.6508, 0.7716, 0.6274, 0.1363, 0.7613, 0.2494, 0.3428, 0.9261],
            [0.7717, 0.4577, 0.0172, 0.8652, 0.8721, 0.1606, 0.9938, 0.0563, 0.5016, 0.5695],
            [0.1718, 0.2058, 0.9203, 0.6626, 0.5851, 0.1158, 0.4355, 0.0705, 0.7436, 0.2993],
            [0.3664, 0.5945, 0.0679, 0.5061, 0.9727, 0.2606, 0.8866, 0.7922, 0.8299, 0.4857],
            [0.1521, 0.6116, 0.8895, 0.8145, 0.8965, 0.3734, 0.5860, 0.2676, 0.9310, 0.6909],
            [0.6845, 0.6834, 0.3010, 0.4828, 0.8348, 0.3919, 0.4748, 0.6864, 0.1595, 0.2750],
            [0.1263, 0.8909, 0.6362, 0.1842, 0.4666, 0.9637, 0.3681, 0.4188, 0.3020, 0.7719],
        ]
    ),
    'weight': np.array(
        [3.9238, 3.1421, 6.5905, 2.8718, 1.7819,
         1.4975, 0.1834, 7.6811, 5.0964, 2.7595]
    ),
    'y': np.array(
        [14.3773, 22.0280, 21.5997, 24.4789, 25.8981,
         15.0056, 20.0997, 17.3859, 20.7086, 22.1846]
    ),
}
def model():
    """NumPyro model: Bayesian linear regression y ~ Normal(x @ w - b, 1).

    Reads the module-level ``data`` dict ('x' design matrix, 'y' responses).
    """
    n_weights = 10
    # Independent Student-t prior on each weight; NumPyro's StudentT takes
    # (df, loc, scale), so the third value here is a scale, as in Stan.
    w = numpyro.sample(
        "w",
        dist.StudentT(
            df=np.full(n_weights, 78.04098510742188),
            loc=np.full(n_weights, -17.60015869140625),
            scale=np.full(n_weights, 68.0557632446289),
        ),
    )
    # Normal(loc, scale) prior on the offset.
    b = numpyro.sample("b", dist.Normal(1.0, 10.0))
    # Linear predictor, one mean per row of x.
    mean = jnp.matmul(data['x'], w) - b
    # Conditionally independent observations with unit noise.
    with numpyro.plate("size", np.size(data['y'])):
        numpyro.sample("obs27", dist.Normal(mean, 1.0), obs=data['y'])
# 4 NUTS chains, each with 1000 warmup and 1000 retained draws, from a fixed seed.
mcmc = MCMC(
    NUTS(model),
    num_warmup=1000,
    num_samples=1000,
    num_chains=4,
)
mcmc.run(jax.random.PRNGKey(2080339834))
# Posterior draws per sample site, grouped by chain.
params = mcmc.get_samples(group_by_chain=True)
mcmc.print_summary()
NumPyro output:
mean std median 5.0% 95.0% n_eff r_hat
b -1.69 9.71 -1.69 -17.84 13.39 837.72 1.00
w[0] 9.96 8.76 9.98 -3.39 24.46 773.45 1.00
w[1] -15.62 7.25 -15.68 -27.79 -4.56 816.07 1.00
w[2] 4.53 6.98 4.56 -6.41 15.65 773.57 1.00
w[3] 5.05 4.34 5.07 -1.90 12.12 894.61 1.00
w[4] -1.49 3.60 -1.48 -7.36 4.41 1413.75 1.00
w[5] 23.85 7.93 23.81 11.97 37.11 831.59 1.00
w[6] 18.87 8.01 18.98 6.06 31.51 777.94 1.00
w[7] 4.07 3.26 4.10 -1.31 9.27 881.17 1.00
w[8] -4.01 3.05 -4.00 -8.93 0.90 881.75 1.00
w[9] -0.50 1.98 -0.54 -3.83 2.62 1546.71 1.00