Hi all,
I have a similar suspicion to a recent post: Minibatch not working
That is, I suspect Minibatch is just picking one random batch at the start and sticking with it, leading to overfitting on the instances picked at the start rather than sampling across all instances.
Why? Because even after fitting a converged model with a reasonable batch size, I only get good predictions for roughly as many instances as the batch size itself. It's as if all the other instances were a hold-out test set.
As an example, here are the prediction results for a simple categorical regression problem (the features X are categorical, the outputs y are floats). I set the batch size to a range of values from small to reasonable (5, 25, 45, 65, 85). Each line shows the absolute error for the first 100 instances, ordered by error. As you can see, it looks as though only N = batch_size instances are really being seen by the regressor: they get a low error, and the curve pivots sharply upwards beyond N instances.
Arguably, increasing the batch size simply produces a better model, but the elbow in each curve lines up so neatly with the batch size that it seems like too much of a coincidence! I would expect an increasingly better model to generalize to more instances than just batch_size of them.
So: is there any way to verify that Minibatch is choosing new instances at each iteration? Can I print the indices drawn by the minibatch variable at each step?
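For concreteness, here's the kind of check I have in mind (a minimal sketch with made-up toy data of the same shape as my real problem; I'm assuming that, when Minibatch works as intended, each eager evaluation of the minibatch tensor should draw a fresh random slice, though I'm not sure .eval() is the right way to trigger that):

import numpy as np
import pymc as pm

# hypothetical toy data, same shapes as my real problem
X = np.random.randint(0, 2, size=(1000, 512)).astype(float)
y = np.random.randn(1000)

minibatch_x, minibatch_y = pm.Minibatch(X, y, batch_size=20)

# if Minibatch resamples, repeated evaluations should return different rows;
# getting the identical slice every time would confirm my suspicion
draws = [minibatch_x.eval() for _ in range(5)]
for a, b in zip(draws[:-1], draws[1:]):
    print("identical slice:", np.array_equal(a, b))

If there's a more direct way to inspect the indices themselves, I'd love to hear it.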
Here's the code for the full experiment:
import numpy as np
import pymc as pm
import matplotlib.pyplot as plt

errs = []
for bsize in [5, 25, 45, 65, 85]:
    print(bsize)
    minibatch_x, minibatch_y = pm.Minibatch(X, y, batch_size=bsize)
    with pm.Model() as model:
        sigma = pm.HalfCauchy("sigma", beta=3)
        weights = pm.Normal("W", 0, sigma=sigma, shape=512)
        bias = pm.Normal("b", 0, 1)
        y_pred = (weights * minibatch_x).sum(1) + bias
        likelihood = pm.Normal("y",
                               mu=y_pred,
                               sigma=sigma,
                               observed=minibatch_y,
                               total_size=len(y)
                               )
        mean_field = pm.fit(200_000)

    # draw from the posterior to get a range of values for the weights and
    # bias, with which to make new predictions
    mean_field_samples = mean_field.sample(draws=5000)
    w_samples = mean_field_samples['posterior']['W'][0].values
    b_samples = mean_field_samples['posterior']['b'][0].values
    ps = []
    for i in range(1000):
        idx = np.random.choice(w_samples.shape[0])
        w_ = w_samples[idx]
        b_ = b_samples[idx]
        p = w_.dot(X.T) + b_   # predict on the full feature matrix X
        ps.append(p)
    ps = np.array(ps)
    # take the mean prediction to calculate the error
    err = ps.mean(0) - y
    errs.append(err)
for err, batch_size in zip(errs, [5, 25, 45, 65, 85]):
    plt.plot(np.abs(err)[np.abs(err).argsort()][:100], label=batch_size)
plt.ylim(0, 0.5)
for c, batch_size in enumerate([5, 25, 45, 65, 85]):
    plt.axvline(batch_size, c=f'C{c}', linestyle='--')
plt.legend()
plt.xlabel('Instances (ordered by error)')
plt.ylabel('Absolute error')
One more test to convince you (and me):
Run with just a reasonably small batch size of 20, then look at the posterior distribution of the weights vector. Some categorical variables are fitted well; others simply resemble the prior, as if those categories had never been encountered by the model! Fitting to the full dataset with NUTS or ADVI does not reproduce this, indicating that those variables are indeed present in the data.
import seaborn

minibatch_x, minibatch_y = pm.Minibatch(X, y, batch_size=20)
with pm.Model() as model:
    sigma = pm.HalfCauchy("sigma", beta=3)
    weights = pm.Normal("W", 0, sigma=sigma, shape=512)
    bias = pm.Normal("b", 0, 1)
    y_pred = (weights * minibatch_x).sum(1) + bias
    likelihood = pm.Normal("y",
                           mu=y_pred,
                           sigma=sigma,
                           observed=minibatch_y,
                           total_size=len(y)
                           )
    mean_field = pm.fit(200_000)

mean_field_samples = mean_field.sample(draws=5000)

plt.figure(figsize=(10, 5))
# pool all chains and plot the posterior of the first 15 weights
seaborn.boxplot(data=np.concatenate(mean_field_samples['posterior']['W'])[:, :15])
plt.xlabel('Categorical variable index')
plt.ylabel('Posterior distribution of learned weights')
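To put a rough number on the "resembles the prior" observation rather than just eyeballing the boxplot, this is the kind of heuristic I used. It's only a sketch: the 0.9 threshold is arbitrary, and I'm treating the posterior mean of sigma as the prior scale for each weight, since W ~ Normal(0, sigma).

# pool posterior samples across chains: (n_samples, 512) and (n_samples,)
w_all = np.concatenate(mean_field_samples['posterior']['W'])
sigma_hat = np.concatenate(mean_field_samples['posterior']['sigma']).mean()

# a weight the model never updated should stay centred near zero with a
# spread close to the prior scale sigma_hat
post_sd = w_all.std(axis=0)
looks_like_prior = post_sd > 0.9 * sigma_hat
print(f"{looks_like_prior.sum()} of {w_all.shape[1]} weights still look like the prior")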