Hi everyone!
I am fitting a very simple linear regression model:
import numpy as np
import pymc3 as pm

# length_i and out_i are defined in the data-generation snippet below
with pm.Model(coords={'obs': np.arange(100)}) as model:
    sigma = pm.HalfNormal('sigma', sigma=100)
    a_0 = pm.Normal('a_0', mu=0, sigma=100)
    a_1 = pm.Normal('a_1', mu=0, sigma=100)
    pm.Normal(
        'outcome',
        mu=a_0 + a_1 * length_i,  # linear predictor
        sigma=sigma,
        observed=out_i,
        dims='obs',               # 'obs' is registered via coords above
    )
With data produced by:
LoT_lengths = np.arange(10)            # the 10 possible lengths, 0..9
cat_i = np.random.randint(0, 10, 100)  # a random category for each of 100 observations
length_i = LoT_lengths[cat_i]
out_i = 30 + 70 * np.random.normal(loc=length_i, scale=1)
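(Side note on the noise scale: since the unit-scale normal draw gets multiplied by 70, this is equivalent to generating the data with intercept 30, slope 70, and observation noise of scale 70:)

# equivalent data-generating process: the noise sd is 70, not 1
out_i = np.random.normal(loc=30 + 70 * length_i, scale=70)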
When I sample with NUTS, everything works as expected:
with model:
    trace = pm.sample(
        1000,
        cores=1,
        return_inferencedata=False,
        target_accept=0.95,
    )
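As a quick numeric check that the generating values are recovered (roughly a_0 = 30, a_1 = 70 and, given the note on the noise scale above, sigma = 70):

import arviz as az

print(az.summary(az.from_pymc3(trace, model=model), var_names=['a_0', 'a_1', 'sigma']))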
And plotting posterior regression lines over the data:

import matplotlib.pyplot as plt

trace_az = az.from_pymc3(trace, model=model)
a0_trace = trace_az.posterior['a_0'].values.flatten()
a1_trace = trace_az.posterior['a_1'].values.flatten()
sigma_trace = trace_az.posterior['sigma'].values.flatten()

plt.scatter(length_i, out_i)  # the raw data
xs = np.linspace(0, 10, 2)    # two x values suffice for a straight line
for a0, a1, s in zip(a0_trace, a1_trace, sigma_trace):
    plt.plot(
        xs,
        a0 + a1 * xs,         # one regression line per posterior draw
        color='blue',
        alpha=0.05,
        linewidth=1,
    )
However, when I fit the same model with the variational API:

with model:
    fit = pm.fit()              # defaults: ADVI, 10000 iterations
    fit_samples = fit.sample()  # 500 draws by default
trace_az = az.from_pymc3(fit_samples, model=model)
the result looks really wrong (same plotting code as above):
This is despite the fit apparently having converged:
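(For reference, the convergence I'm referring to is the ELBO history stored on the approximation, plotted with something like:)

plt.plot(fit.hist)  # negative ELBO per iteration; a flat tail suggests ADVI has converged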
Now, I know that VI can struggle with multimodal posteriors and awkward geometries, but it seems strange that it would fail so badly in such a simple case! Am I doing something wrong? Thank you very much for your help!
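In case it helps diagnose the problem, here is the variant I would try next, assuming the defaults are the issue (as I understand pm.fit, it accepts a method argument and convergence callbacks):

with model:
    fit = pm.fit(
        n=100_000,               # run ADVI much longer than the default 10000 steps
        method='fullrank_advi',  # full-rank ADVI also captures posterior correlations
        callbacks=[pm.callbacks.CheckParametersConvergence(diff='absolute')],
    )
    fit_samples = fit.sample(1000)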