Verifying that minibatch is actually randomly sampling

Hi all,
I have a similar suspicion to a recent post: Minibatch not working

That is, I suspect minibatch is just picking one random batch at the start and sticking with it - leading to overfitting on the instances picked for that first batch, rather than actually sampling across all instances.

Why? Because even after fitting a converged model using a reasonable batch size, I only get good predictions for approximately the same number of instances as the batch size itself. It’s as if all the other instances are a hold-out test set.

As an example, here are the prediction results for a simple regression problem with categorical features (X is categorical, y is a float). I set the batch size to a range of values from small to reasonable (5, 25, 45, 65, 85). Each line shows the absolute error of the 100 lowest-error instances, sorted by error. As you can see, it looks as though only N=batch_size instances are really being seen by the regressor - they get a low error, and the curve pivots sharply upwards beyond N instances.

Arguably increasing the batch size is just creating a better model, but the elbow in each curve lines up so neatly with the batch size that it seems too much of a coincidence! I would expect an increasingly better model to generalize to more samples than just the batch size.

So: is there any way to verify that minibatch is choosing new instances in each iteration? Can I print the indices of the minibatch variable at each step?
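
Something along these lines is the kind of check I’m imagining - just a rough sketch, and I’m assuming pm.pytensorf.compile_pymc is the right way to compile the graph so that the RNG updates are applied and each call draws a fresh slice:

import numpy as np
import pymc as pm

#minibatch the row indices alongside X, so every draw tells us which rows it picked
row_ids = np.arange(len(y))
mb_x, mb_ids = pm.Minibatch(X, row_ids, batch_size=20)

#compile_pymc should attach the RNG updates so the slice changes between calls
draw = pm.pytensorf.compile_pymc([], [mb_x, mb_ids])
for _ in range(5):
    x_batch, ids = draw()
    #expect a different set of indices each time, and rows that actually match X
    print(ids, np.array_equal(x_batch, X[ids]))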

code:

errs = []
for bsize in [5, 25, 45, 65, 85]:
    print(bsize)
    minibatch_x, minibatch_y = pm.Minibatch(X, y, batch_size=bsize)

    with pm.Model() as model:
        sigma = pm.HalfCauchy("sigma", beta=3)
        weights = pm.Normal("W", 0, sigma=sigma, shape=512)
        bias = pm.Normal("b", 0, 1)
        y_pred = (weights*minibatch_x).sum(1) + bias
        likelihood = pm.Normal("y", 
                            mu=y_pred, 
                            sigma=sigma, 
                            observed=minibatch_y, 
                            total_size=len(y)
                            )
        
        mean_field = pm.fit(200_000)

    #draw from posterior to get a range of values for the weights and 
    #bias, with which to make new predictions
    mean_field_samples = mean_field.sample(draws=5000)
    w_samples = mean_field_samples['posterior']['W'][0].values
    b_samples = mean_field_samples['posterior']['b'][0].values

    ps = []
    for i in range(1000):
        idx = np.random.choice(w_samples.shape[0])
        w_ = w_samples[idx]
        b_ = b_samples[idx]
        p = w_.dot(X.T) + b_  #predict for every instance in X
        ps.append(p)
    ps = np.array(ps)

    #take the mean prediction to calculate the error.
    err = ps.mean(0)-y
    errs.append(err)
for err, batch_size in zip(errs, [5, 25, 45, 65, 85]):
    plt.plot(np.abs(err)[np.abs(err).argsort()][:100], label=batch_size)
plt.ylim(0, 0.5)
for c, batch_size in enumerate([5, 25, 45, 65, 85]):
    plt.axvline(batch_size, c=f'C{c}', linestyle='--')
plt.legend()
plt.xlabel('Instances (ordered by error)')
plt.ylabel('Absolute error')
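
A second check I’d like to run, sketched under the same compile_pymc assumption as above: count how many distinct rows the minibatch actually touches over repeated draws. If the indices rotate properly, the count should climb towards len(y); if the batch is stuck, it should flatten out at roughly the batch size.

import numpy as np
import pymc as pm
import matplotlib.pyplot as plt

#rough coverage check - again assuming compile_pymc applies the minibatch RNG updates
mb_x, mb_ids = pm.Minibatch(X, np.arange(len(y)), batch_size=20)
draw_ids = pm.pytensorf.compile_pymc([], mb_ids)

seen = set()
coverage = []
for _ in range(500):
    seen.update(int(i) for i in draw_ids())
    coverage.append(len(seen))

plt.plot(coverage)
plt.xlabel('Minibatch draws')
plt.ylabel('Distinct rows seen so far')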

one more test to convince you/me:
Run with just a reasonably small batch size of 20. Then look at the distribution of the weights vector in the posterior. We can see that some categorical variables are fitted well, while others simply resemble the prior - it’s as if those categories haven’t been encountered by the model! Fitting to the full dataset with NUTS or ADVI does not reproduce this, indicating that those variables are indeed present in the data.

minibatch_x, minibatch_y = pm.Minibatch(X, y, batch_size=20)

with pm.Model() as model:
    sigma = pm.HalfCauchy("sigma", beta=3)
    weights = pm.Normal("W", 0, sigma=sigma, shape=512)
    bias = pm.Normal("b", 0, 1)
    y_pred = (weights*minibatch_x).sum(1) + bias
    likelihood = pm.Normal("y", 
                        mu=y_pred, 
                        sigma=sigma, 
                        observed=minibatch_y, 
                        total_size=len(y)
                        )
    
    mean_field = pm.fit(200_000)
mean_field_samples = mean_field.sample(draws=5000)

plt.figure(figsize=(10, 5))
seaborn.boxplot(
    np.concatenate(mean_field_samples['posterior']['W'])[:, :15]
)
plt.xlabel('Categorical variable index')
plt.ylabel('Posterior distribution of learned weights')

forgot to add:

>>> pm.__version__
'5.11.0'

CC @ferrine

no updates to add from my end unfortunately, but the issue persists

Can you provide a small fully reproducible script?

sure thing @ricardoV94 (sorry for the delay).

there’s a gist here: example_for_ricardo.ipynb · GitHub

however I’ll write out what I consider the strongest evidence here, for posterity:

import numpy as np
import seaborn
import pymc as pm
import matplotlib.pyplot as plt

#generate fake feature data - sparse, binary, high dimensional.
nfeats = 512
Xs = []
for i in range(100):
    token_occurence = np.random.uniform(0,0.02, nfeats)
    on_probability = np.random.uniform(0, 1, (10, nfeats))
    Xs.append(on_probability<token_occurence)
X = np.vstack(Xs)

#generate some fake outputs
coeffs = np.random.normal(0,1, nfeats)
intercept = np.random.uniform(-10, 10)

y = X.dot(coeffs)+intercept

#now we have X,y pairs.

##Let's sample with a batch size of 20.
##It's small, but there are enough iterations that we should sample
##all the available features pretty thoroughly by rotating the batch indices.
##But, looking at the learned posterior, we can see that many
##of the features simply resemble the prior - they haven't encountered
##any data at all.
minibatch_x, minibatch_y = pm.Minibatch(X, y, batch_size=20,)

with pm.Model() as model:
    sigma = pm.HalfCauchy("sigma", beta=3)
    weights = pm.Normal("W", 0, sigma=sigma, shape=512)
    bias = pm.Normal("b", 0, 1)
    y_pred = (weights*minibatch_x).sum(1) + bias
    likelihood = pm.Normal("y", 
                        mu=y_pred, 
                        sigma=sigma, 
                        observed=minibatch_y, 
                        total_size=len(y)
                        )
    
    mean_field = pm.fit(200_000)
mean_field_samples = mean_field.sample(draws=5000)


plt.figure(figsize=(10, 5))
seaborn.boxplot(
    np.concatenate(mean_field_samples['posterior']['W'])[:, :30]
)
plt.xlabel('Categorical variable index')
plt.ylabel('Posterior distribution of learned weights')

#plot actual coefficients:
plt.plot(coeffs[:30], '-o', label='Actual coefficients')
plt.legend()

##Now bump up the batch size to 200: our previous problem has disappeared.
##In this case the model is seeing more of the features simply because each batch
##is larger, which suggests mini-batching was not successfully rotating the sample indices.
minibatch_x, minibatch_y = pm.Minibatch(X, y, batch_size=200,)

with pm.Model() as model:
    sigma = pm.HalfCauchy("sigma", beta=3)
    weights = pm.Normal("W", 0, sigma=sigma, shape=512)
    bias = pm.Normal("b", 0, 1)
    y_pred = (weights*minibatch_x).sum(1) + bias
    likelihood = pm.Normal("y", 
                        mu=y_pred, 
                        sigma=sigma, 
                        observed=minibatch_y, 
                        total_size=len(y)
                        )
    
    mean_field = pm.fit(20_000)
mean_field_samples = mean_field.sample(draws=5000)


plt.figure(figsize=(10, 5))
seaborn.boxplot(
    np.concatenate(mean_field_samples['posterior']['W'])[:, :30]
)
plt.xlabel('Categorical variable index')
plt.ylabel('Posterior distribution of learned weights')

#plot actual coefficients:
plt.plot(coeffs[:30], '-o', label='Actual coefficients')
plt.legend()

outputs are: [plots omitted - see the gist above]

finally: something I should have noticed earlier. Every time I call pm.fit I get a warning:

/Users/ljmartin/miniforge3/envs/compchem/lib/python3.11/site-packages/pymc/pytensorf.py:845:
UserWarning: RNG Variable RandomGeneratorSharedVariable(<Generator(PCG64) at 0x31A16D9A0>) has multiple clients. 
This is likely an inconsistent random graph.
  warnings.warn(

any help interpreting that would be fantastic.
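
In case it helps, here’s my (possibly wrong) reading of “multiple clients”: the same RandomGeneratorSharedVariable feeds more than one random Op in the graph, so only one of them can supply the update for the RNG state and the draws can fall out of sync. A minimal sketch of what I think that situation looks like (not necessarily what Minibatch does internally):

import numpy as np
import pytensor
import pytensor.tensor as pt

rng = pytensor.shared(np.random.default_rng(0))
a = pt.random.normal(0, 1, rng=rng)  #client 1 of the shared RNG
b = pt.random.normal(0, 1, rng=rng)  #client 2 - the same shared RNG feeds two random Ops
#each Op produces its own "next RNG state", but the shared variable can only be
#updated with one of them, so the random graph ends up ambiguous/inconsistent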

thank you!

Yes, that warning is pretty important and indeed suggests an error.

OK, thanks Ricardo! That shifts my attention a little bit. Perhaps this whole issue should be about identifying the source of that warning first, rather than verifying that minibatch is rotating indices.

Do you have any insight into what causes the warning? I couldn’t find any previous issues mentioning exactly this warning.

I’ll have to investigate. Could I just ask you to test on the latest release of PyMC? We solved some bugs around minibatch indexing some time ago, but I can’t tell whether that was before or after the version you were using.

sure thing - I upgraded to 5.16.1 using mamba update pymc, which I can see is the latest release on the GitHub repo too, and I can confirm I’m still seeing the same behaviour as well as the warning. Happy to try something else, too. BTW - was the reproducible script helpful? If you have time, I’d love to know whether you see the same warning or if it’s specific to my setup (M1 MBP).

realised pytensor might be relevant too - I’ve got 2.23.0. I tried upgrading to 2.24.1 but it’s incompatible with pymc 5.16.1.

@lewiso1 I was able to reproduce the problem with your gist, thanks so much! I’ll try to push a patch soon

Your plot looks like this after my bugfix:
[plot after the bugfix omitted]

Bugfix: Fix bug with multiple minibatch variables by ricardoV94 · Pull Request #7408 · pymc-devs/pymc · GitHub

Also, apologies for taking so long - I see your first message was back in April! :grimacing:

OMG - awesome! Thank you Ricardo - appreciate you taking a look and making the fix. Excited to use it again

edit: the fix works for me too :slight_smile:
