Running this model with PyMC v4 leads to an initialization error. It seems that large values from LogNormal overflow to infinity.
If I change the LogNormal parameters to smaller values like ~1-5, then it works. Interestingly, running the same model with Stan works fine.
Any idea why this is happening? Is this an overflow issue? Can this be a bug?
import pymc as pm
import aesara
import arviz as az
import os
import json
import random
import sys
import numpy as np
# Minimal reproducer for "Initial evaluation of model at starting point failed!".
# The model and sampler settings are intentionally left as reported so that the
# script still triggers the error.

# Seed Python's and NumPy's global RNGs.
random.seed(685143202)
np.random.seed(28571822)

# Input data wrapped as Aesara shared variables (float64).
data = dict()
data['x'] = aesara.shared(
    np.array(
        [
            [0.3127, 0.7185, 0.9494, 0.5597, 0.7358, 0.1210, 0.6191, 0.5667, 0.4503, 0.5986],
            [0.0269, 0.7518, 0.9182, 0.6499, 0.4869, 0.3073, 0.8153, 0.6999, 0.0878, 0.1670],
            [0.7447, 0.5607, 0.0876, 0.7848, 0.5527, 0.6284, 0.2125, 0.6352, 0.8686, 0.7983],
            [0.1576, 0.0972, 0.6508, 0.7716, 0.6274, 0.1363, 0.7613, 0.2494, 0.3428, 0.9261],
            [0.7717, 0.4577, 0.0172, 0.8652, 0.8721, 0.1606, 0.9938, 0.0563, 0.5016, 0.5695],
            [0.1718, 0.2058, 0.9203, 0.6626, 0.5851, 0.1158, 0.4355, 0.0705, 0.7436, 0.2993],
            [0.3664, 0.5945, 0.0679, 0.5061, 0.9727, 0.2606, 0.8866, 0.7922, 0.8299, 0.4857],
            [0.1521, 0.6116, 0.8895, 0.8145, 0.8965, 0.3734, 0.5860, 0.2676, 0.9310, 0.6909],
            [0.6845, 0.6834, 0.3010, 0.4828, 0.8348, 0.3919, 0.4748, 0.6864, 0.1595, 0.2750],
            [0.1263, 0.8909, 0.6362, 0.1842, 0.4666, 0.9637, 0.3681, 0.4188, 0.3020, 0.7719],
        ]
    ).astype("float64")
)
# NOTE: 'weight' is defined here but is not referenced by the model below.
data['weight'] = aesara.shared(
    np.array(
        [3.9238, 3.1421, 6.5905, 2.8718, 1.7819, 1.4975, 0.1834, 7.6811, 5.0964, 2.7595]
    ).astype("float64")
)
data['y'] = aesara.shared(
    np.array(
        [14.3773, 22.0280, 21.5997, 24.4789, 25.8981, 15.0056, 20.0997, 17.3859, 20.7086, 22.1846]
    ).astype("float64")
)

# Linear regression y ~ Normal(x @ w + b, 1) with a Lognormal hyper-prior on
# the mean of the weights.
with pm.Model() as model:
    # The reported starting value of this variable on the log scale is
    # 831.76, which equals mu + sigma**2 / 2 for the parameters below.
    # exp(831.76) is not representable in float64 (exp overflows above ~709.78),
    # which appears to be why the starting value of `w` is reported as inf and
    # the initial log-probability evaluates to nan / -inf.
    # NOTE(review): passing an explicit `initval=` to the affected variables
    # may sidestep this -- unverified here, left out to keep the reproducer intact.
    w_param_0 = pm.Lognormal('w_param_0', 54.28338623046875, 39.43290710449219)
    w = pm.Normal('w', w_param_0, 10.0, shape=(10,))
    b = pm.Normal('b', 1.0, 10.0)
    obs1 = pm.Normal(
        'obs1',
        aesara.tensor.dot(data['x'], w) + b,
        1.0,
        observed=data['y'],
    )

# Sample with an explicitly configured NUTS step method.
with model:
    step = pm.step_methods.hmc.nuts.NUTS(max_treedepth=10, target_accept=0.8, step_scale=0.25)
    # This is the call that raises SamplingError in the reported run.
    samples = pm.sample(
        draws=1000,
        chains=4,
        tune=1000,
        step=step,
        init="auto",
        jitter_max_retries=10,
    )
Output:
WARNING (aesara.tensor.blas): Using NumPy C-API based implementation for BLAS functions.
Traceback (most recent call last):
File "pymc4file.py", line 33, in <module>
samples = pm.sample(draws=1000, chains=4, tune=1000, step=step, init="auto", jitter_max_retries=10)
File "pymc/pymc/sampling.py", line 539, in sample
model.check_start_vals(ip)
File "pymc/pymc/model.py", line 1709, in check_start_vals
"Initial evaluation of model at starting point failed!\n"
pymc.exceptions.SamplingError: Initial evaluation of model at starting point failed!
Starting values:
{'w_param_0_log__': array(831.76046759), 'w': array([inf, inf, inf, inf, inf, inf, inf, inf, inf, inf]), 'b': array(1.)}
Initial evaluation results:
{'w_param_0': -198.96, 'w': nan, 'b': -3.22, 'obs1': -inf}
Stan model:
// Stan version of the same regression model: y ~ normal(x * w + b, 1).
// Observed inputs.
data{
// 10 x 10 design matrix.
matrix[10,10] x;
// Declared as data but not referenced in the model block below.
real weight[10];
// Observed responses.
vector[10] y;
}
// Quantities to be sampled.
parameters{
// Regression weights.
vector[10] w;
// Intercept.
real b;
// Shared prior mean of the weights; declared with a lower bound of 0.
real<lower=0> w_param_0;
}
// Priors and likelihood.
model{
// Hyper-prior on the weights' mean.
w_param_0~lognormal(54.28338623046875,39.43290710449219);
// Every weight shares the mean w_param_0 with standard deviation 10.
w~normal(w_param_0,10.0);
b~normal(1.0,10.0);
// Likelihood: linear predictor x*w+b with unit standard deviation.
y~normal(x*w+b,1.0);
}
Stan Output:
Mean MCSE StdDev 5% 50% 95% N_Eff N_Eff/s R_hat
w[1] 1.4e-01 1.7e-01 4.6e+00 -7.7e+00 2.0e-01 7.8 7.3e+02 8.2e+02 1.0e+00
w[2] -6.5e+00 1.5e-01 4.1e+00 -1.3e+01 -6.5e+00 0.21 8.0e+02 9.0e+02 1.0e+00
w[3] -3.8e+00 1.4e-01 3.7e+00 -1.0e+01 -3.7e+00 2.4 6.9e+02 7.7e+02 1.0e+00
w[4] 8.1e+00 8.2e-02 3.1e+00 2.8e+00 8.1e+00 13 1.4e+03 1.6e+03 1.0e+00
w[5] -1.1e+00 7.5e-02 3.2e+00 -6.3e+00 -1.0e+00 4.0 1.8e+03 2.0e+03 1.0e+00
w[6] 1.3e+01 1.6e-01 4.5e+00 6.1e+00 1.3e+01 21 7.7e+02 8.6e+02 1.0e+00
w[7] 9.5e+00 1.6e-01 4.2e+00 2.4e+00 9.5e+00 17 6.9e+02 7.7e+02 1.0e+00
w[8] 4.1e-01 6.4e-02 2.1e+00 -3.2e+00 4.1e-01 4.0 1.1e+03 1.2e+03 1.0e+00
w[9] -6.5e+00 6.4e-02 2.0e+00 -9.9e+00 -6.5e+00 -3.2 1.0e+03 1.1e+03 1.0e+00
w[10] 1.2e+00 4.1e-02 1.7e+00 -1.8e+00 1.3e+00 4.0 1.7e+03 1.9e+03 1.0e+00
b 1.4e+01 2.0e-01 5.7e+00 4.4e+00 1.4e+01 23 8.2e+02 9.1e+02 1.0e+00
w_param_0 3.4e-01 2.3e-02 1.0e+00 1.3e-21 9.8e-06 2.4 1.9e+03 2.1e+03 1.0e+00
Versions and main components:
- PyMC Version: v4.0.0b3
- Aesara Version: 2.4.0
- Python Version: 3.7.11
- Operating system: Ubuntu 18.04