Strange results with simple linear regression + variational inference

Hi everyone!

I am fitting a very simple linear regression model:

    with pm.Model() as model:

        sigma = pm.HalfNormal('sigma', sigma=100)
        a_0 = pm.Normal('a_0', mu=0, sigma=100)
        a_1 = pm.Normal('a_1', mu=0, sigma=100)
        
        pm.Normal(
            'outcome', 
            mu=a_0+a_1*length_i, 
            sigma=sigma, 
            observed=out_i,
            dims='obs'
        )

With data produced by:

    import numpy as np

    LoT_lengths = np.arange(10)
    cat_i = np.random.randint(0,10,100)
    # predictor: one length per observation
    length_i = LoT_lengths[cat_i]
    out_i = 30 + 70 * np.random.normal(loc=length_i, scale=1)
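
For reference, expanding the last line gives out_i = 30 + 70*length_i + 70*noise, so the values the model should recover are roughly a_0 = 30, a_1 = 70 and sigma = 70. Below is a minimal sanity check using ordinary least squares in plain NumPy. The fixed seed and the names `rng`, `slope`, `intercept` and `residual_sd` are my own additions for illustration and are not part of the model above.

```python
import numpy as np

# Fixed seed is an assumption, only here to make the check reproducible.
rng = np.random.default_rng(0)

# Same data-generating process as in the post.
LoT_lengths = np.arange(10)
cat_i = rng.integers(0, 10, 100)
length_i = LoT_lengths[cat_i]
out_i = 30 + 70 * rng.normal(loc=length_i, scale=1)

# Ordinary least squares fit; polyfit returns the highest-degree coefficient first.
slope, intercept = np.polyfit(length_i, out_i, deg=1)

# Residual standard deviation, with 2 degrees of freedom used by the fit.
residuals = out_i - (intercept + slope * length_i)
residual_sd = np.std(residuals, ddof=2)

print(f"a_0 ~ {intercept:.1f}, a_1 ~ {slope:.1f}, sigma ~ {residual_sd:.1f}")
```

With only 100 points the intercept estimate is fairly noisy, so it can land some distance from 30, while the slope and the residual scale should both come out close to 70.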

When I sample with NUTS, everything works as expected:

    with model:
        trace = pm.sample(
            1000, 
            cores=1, 
            return_inferencedata=False,
            target_accept=0.95
        )

And plotting the data:

    trace_az = az.from_pymc3(trace, model=model)
    a0_trace = trace_az.posterior['a_0'].values.flatten()
    a1_trace = trace_az.posterior['a_1'].values.flatten()
    sigma_trace = trace_az.posterior['sigma'].values.flatten()

    # scatter the raw data (x-values assumed to be length_i)
    plt.scatter(length_i, out_i)

    # one regression line per posterior draw
    xs = np.linspace(0,10,2)
    for a0,a1,s in zip(a0_trace, a1_trace, sigma_trace):
        plt.plot(
            xs,
            a0+a1*xs,
            color='blue',
            alpha=0.05,
            linewidth=1
        )

[image: regression lines from the NUTS posterior plotted over the data]

However, when I fit the same model with the variational API:

    with model:
        fit = pm.fit()
    fit_samples = fit.sample()
    trace_az = az.from_pymc3(fit_samples, model=model)

the result looks really wrong (same plotting code as above):

[image: regression lines from the variational fit plotted over the data]

Despite the model having apparently converged:

[image: convergence plot for the variational fit]

Now, I know that VI can do badly, for example by missing multimodality or when the posterior has a difficult geometry, but it seems strange that it would be this bad in such a simple case! Am I doing something wrong? Thank you very much for your help!