Hi all,
It seems that pm.Minibatch may not be handling batches of data correctly and might be training the model on only a single batch (effectively making the total training data equal to the batch size). This is only a suspicion, based on the model outputs and the warning message I am receiving, as I am unsure what is happening in the backend.
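For context, here is my understanding of how total_size is supposed to work (my own sketch, not PyMC internals): the log-likelihood of each minibatch is rescaled by N / batch_size so that, in expectation over random batches, it matches the full-data log-likelihood. A pure-NumPy illustration of that rescaling:

```python
import numpy as np

rng = np.random.default_rng(0)
N, batch_size = 10000, 100
data = rng.normal(1.0, 1.0, size=N)

def gaussian_loglike(x, mu=1.0, sigma=1.0):
    # Sum of Gaussian log-densities over a data vector x
    return np.sum(-0.5 * np.log(2 * np.pi * sigma**2)
                  - 0.5 * ((x - mu) / sigma) ** 2)

full = gaussian_loglike(data)

# Average the rescaled minibatch log-likelihood over many random batches
scaled = []
for _ in range(2000):
    idx = rng.integers(0, N, size=batch_size)
    scaled.append(gaussian_loglike(data[idx]) * (N / batch_size))

print(full, np.mean(scaled))  # the two should be close
```

If only a single fixed batch were ever used, this expectation argument would break down, which is what I suspect is happening.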
Here is the code to generate dummy data and to run the model with minibatches:
import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pymc as pm
import pytensor.tensor as pt

# generate data
N = 10000
P = 3
rng = np.random.default_rng(88)
X = rng.uniform(2, 10, size=(N, P))
beta = np.array([1.5, 0.2, -0.9])
y = np.matmul(X, beta) + rng.normal(0, 1, size=(N,))

# minibatch
X_mb = pm.Minibatch(X, batch_size=100)
y_mb = pm.Minibatch(y, batch_size=100)
# model with minibatch
with pm.Model() as model_mb:
    b = pm.Normal("b", mu=0, sigma=3, shape=(P,))
    sigma = pm.HalfCauchy("sigma", 1)
    mu = pt.matmul(X_mb, b)
    likelihood = pm.Normal(
        "likelihood", mu=mu, sigma=sigma, observed=y_mb, total_size=N
    )
    fit_mb = pm.fit(
        n=100000,
        method="advi",
        progressbar=True,
        callbacks=[pm.callbacks.CheckParametersConvergence()],
        random_seed=88,
    )
idata_mb = fit_mb.sample(500)
Output:
100.00% [100000/100000 00:34<00:00 Average Loss = 287.9]
Finished [100%]: Average Loss = 287.88
UserWarning: Could not extract data from symbolic observation likelihood
  warnings.warn(f"Could not extract data from symbolic observation {obs}")
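One thing I wondered about (purely speculation on my part): if the two separate pm.Minibatch calls draw their random slices independently, then the X rows and y values within a batch would no longer correspond to each other. A pure-NumPy sketch of that failure mode, with hypothetical slice draws standing in for whatever PyMC actually does internally:

```python
import numpy as np

rng = np.random.default_rng(88)
N, P, batch_size = 10000, 3, 100
X = rng.uniform(2, 10, size=(N, P))
beta = np.array([1.5, 0.2, -0.9])
y = X @ beta + rng.normal(0, 1, size=N)

# Shared slice: rows of X and entries of y stay paired
idx = rng.integers(0, N, size=batch_size)
resid_aligned = y[idx] - X[idx] @ beta

# Independent slices: X rows get paired with unrelated y values
idx_x = rng.integers(0, N, size=batch_size)
idx_y = rng.integers(0, N, size=batch_size)
resid_mismatched = y[idx_y] - X[idx_x] @ beta

print(resid_aligned.std())     # roughly 1, the true noise level
print(resid_mismatched.std())  # much larger: the X-y pairing is destroyed
```

I don't know whether this is what pm.Minibatch does, but it would be consistent with the poor fit I'm seeing.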
Here is the code for the same model using the full data, without minibatching:
# model no minibatch
with pm.Model() as model:
    b = pm.Normal("b", mu=0, sigma=3, shape=(P,))
    sigma = pm.HalfCauchy("sigma", 1)
    mu = pt.matmul(X, b)
    likelihood = pm.Normal("likelihood", mu=mu, sigma=sigma, observed=y, total_size=N)
    fit = pm.fit(
        n=100000,
        method="advi",
        progressbar=True,
        callbacks=[pm.callbacks.CheckParametersConvergence()],
        random_seed=88,
    )
idata = fit.sample(500)
Output:
100.00% [100000/100000 03:44<00:00 Average Loss = 14,196]
When comparing the posteriors of both models against the true beta, the minibatch model performs poorly: its posteriors miss the true values.
# compare models
fig, ax = plt.subplots(nrows=2, ncols=3, figsize=(10, 4), layout="constrained")
az.plot_posterior(
    idata_mb,
    var_names="b",
    ref_val=beta.tolist(),
    ax=ax[0, :],
    textsize=8,
)
az.plot_posterior(idata, var_names="b", ref_val=beta.tolist(), ax=ax[1, :], textsize=8)
for i in range(3):
    ax[1, i].set_xlim(ax[0, i].get_xlim())
ax[0, 0].annotate(
    text="Minibatch",
    xy=(-0.5, 0.5),
    xycoords="axes fraction",
    rotation=90,
    size=15,
    fontweight="bold",
    va="center",
)
ax[1, 0].annotate(
    text="Full Data",
    xy=(-0.5, 0.5),
    xycoords="axes fraction",
    rotation=90,
    size=15,
    fontweight="bold",
    va="center",
)
Any suggestions on how to overcome this issue?
Thanks