Minibatch not working

Thanks @ricardoV94, and thank you @Yamaguchi for running some tests and finding out what the problem was.

Minibatch 2 and Full Data give equivalent results once the axis is adjusted.

But now I've run into another problem: the model doesn’t seem to register/store the observed variables afterwards. As a result, pm.sample_posterior_predictive only produces predictions for a subset of the data that is the same size as the minibatch, and the resulting InferenceData (idata) has no observed_data group.

import pymc as pm
import pytensor.tensor as pt

# minibatch (X, y, N and P are defined earlier in the thread)
X_mb, y_mb = pm.Minibatch(X, y, batch_size=100)

# model with minibatch
with pm.Model() as model_mb2:
    b = pm.Normal("b", mu=0, sigma=3, shape=(P,))
    sigma = pm.HalfCauchy("sigma", 1)
    mu = pt.matmul(X_mb, b)
    likelihood = pm.Normal(
        "likelihood", mu=mu, sigma=sigma, observed=y_mb, total_size=N
    )

    fit_mb2 = pm.fit(
        n=100000,
        method="advi",
        progressbar=True,
        callbacks=[pm.callbacks.CheckParametersConvergence()],
        random_seed=88,
    )
    idata_mb2 = fit_mb2.sample(500)


with model_mb2:
    pm.sample_posterior_predictive(idata_mb2, extend_inferencedata=True)

This gives the following warning:

UserWarning: Could not extract data from symbolic observation likelihood
  warnings.warn(f"Could not extract data from symbolic observation {obs}")

Checking the dimensions:

print("posterior_predictive dims:", dict(idata_mb2.posterior_predictive.dims))
print("posterior dims", dict(idata_mb2.posterior.dims))

posterior_predictive dims: {'chain': 1, 'draw': 500, 'likelihood_dim_2': 100}
posterior dims {'chain': 1, 'draw': 500, 'b_dim_0': 3}
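To make the shape issue concrete, here is a minimal NumPy sketch (not PyMC's implementation) of why the predictive draws come out with length 100: the draw takes the shape of mu, and mu is built from the minibatch rows. The sizes, the random data, and the single "posterior draw" of b and sigma are all made-up stand-ins.

```python
import numpy as np

rng = np.random.default_rng(88)

# Hypothetical stand-ins for the data in this thread.
N, P, batch_size = 1000, 3, 100
X = rng.normal(size=(N, P))
b = rng.normal(size=P)   # pretend this is one posterior draw of b
sigma = 1.0              # pretend this is one posterior draw of sigma

# A minibatch is a set of randomly indexed rows, so mu has batch_size entries.
idx = rng.integers(0, N, size=batch_size)
X_mb = X[idx]
mu_mb = X_mb @ b
y_draw_mb = rng.normal(mu_mb, sigma)

# With the full design matrix there is one prediction per observation.
mu_full = X @ b
y_draw_full = rng.normal(mu_full, sigma)

print(y_draw_mb.shape, y_draw_full.shape)  # (100,) (1000,)
```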

So my workaround for the time being is to rebuild the model on the full data and sample the posterior predictive from the minibatch fit:

with pm.Model() as model_new:
    b = pm.Normal("b", mu=0, sigma=3, shape=(P,))
    sigma = pm.HalfCauchy("sigma", 1)
    mu = pt.matmul(X, b)
    y_pred = pm.Normal("y_pred", mu=mu, sigma=sigma, observed=y, shape=y.shape)

    ppc_new = pm.sample_posterior_predictive(
        idata_mb2,
        var_names=["y_pred"],
        predictions=False,
    )
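As a rough cross-check of the workaround, the same predictive draws can be sketched by hand in NumPy. This assumes the posterior draws of b and sigma are available as plain arrays; the function name manual_posterior_predictive and the synthetic arrays below are hypothetical and only illustrate the shapes involved.

```python
import numpy as np


def manual_posterior_predictive(X, b_draws, sigma_draws, seed=None):
    """Draw y ~ Normal(X @ b, sigma) once per posterior draw.

    X: (N, P), b_draws: (S, P), sigma_draws: (S,). Returns an (S, N) array.
    """
    rng = np.random.default_rng(seed)
    mu = b_draws @ X.T                    # (S, N): one mean vector per draw
    noise = rng.normal(size=mu.shape)     # standard normal noise
    return mu + sigma_draws[:, None] * noise


# Synthetic stand-ins; in practice these would come from the fitted posterior.
rng = np.random.default_rng(0)
N, P, S = 1000, 3, 500
X = rng.normal(size=(N, P))
b_draws = rng.normal(size=(S, P))
sigma_draws = np.abs(rng.normal(size=S))

y_rep = manual_posterior_predictive(X, b_draws, sigma_draws, seed=1)
print(y_rep.shape)  # (500, 1000)
```

With the real objects, something like idata_mb2.posterior["b"].values.reshape(-1, P) should give the (S, P) array of coefficient draws, but I haven't checked that against this exact model.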

I'll open a separate issue on GitHub about the model not storing the observed data.

Thanks.