Verifying that minibatch is actually randomly sampling

Sure thing @ricardoV94 (sorry for the delay).

There's a gist here: example_for_ricardo.ipynb · GitHub

However, I'll write out what I consider the strongest evidence here, for posterity:

import numpy as np
import seaborn
import pymc as pm
import matplotlib.pyplot as plt

# Generate fake feature data: sparse, binary, high-dimensional.
nfeats = 512
Xs = []
for i in range(100):
    token_occurrence = np.random.uniform(0, 0.02, nfeats)
    on_probability = np.random.uniform(0, 1, (10, nfeats))
    Xs.append(on_probability < token_occurrence)
X = np.vstack(Xs)

# Generate some fake outputs.
coeffs = np.random.normal(0, 1, nfeats)
intercept = np.random.uniform(-10, 10)

y = X.dot(coeffs) + intercept

# Now we have X, y pairs.

## Let's sample with a batch size of 20.
## It's small, but there are enough iterations that we should sample
## all the available features pretty thoroughly by rotating the batch indices.
## But, looking at the learned posterior, we can see that many
## of the weight posteriors resemble the prior - those features haven't
## encountered any data at all (see the check after this plotting block).
minibatch_x, minibatch_y = pm.Minibatch(X, y, batch_size=20)

with pm.Model() as model:
    sigma = pm.HalfCauchy("sigma", beta=3)
    weights = pm.Normal("W", 0, sigma=sigma, shape=nfeats)
    bias = pm.Normal("b", 0, 1)
    y_pred = (weights*minibatch_x).sum(1) + bias
    likelihood = pm.Normal("y", 
                        mu=y_pred, 
                        sigma=sigma, 
                        observed=minibatch_y, 
                        total_size=len(y)
                        )
    
    mean_field = pm.fit(200_000)
mean_field_samples = mean_field.sample(draws=5000)


plt.figure(figsize=(10, 5))
seaborn.boxplot(
    np.concatenate(mean_field_samples['posterior']['W'])[:, :30]
)
plt.xlabel('Categorical variable index')
plt.ylabel('Posterior distribution of learned weights')

# Plot the actual coefficients:
plt.plot(coeffs[:30], '-o', label='Actual coefficients')
plt.legend()
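To make "resembles the prior" concrete, here's a quick heuristic check (my own addition, not in the gist): a weight that never saw any data should keep a posterior sd near the prior scale sigma, while weights that did see data should shrink. The 0.8 cut-off below is arbitrary.

# Heuristic check (my addition): count weights whose posterior sd stayed
# near the prior scale, i.e. weights that look untouched by the data.
W_draws = np.concatenate(mean_field_samples['posterior']['W'])  # (draws, 512)
sigma_hat = np.concatenate(mean_field_samples['posterior']['sigma']).mean()
posterior_sd = W_draws.std(axis=0)

prior_like = posterior_sd > 0.8 * sigma_hat  # arbitrary threshold
print(f"{prior_like.sum()} / {nfeats} weights still look like the prior")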

## Now bump the batch size up to 200: our previous problem disappears.
## Here the model sees more of the features simply because each batch is
## larger, which suggests that mini-batching was not rotating the sample
## indices between iterations.
minibatch_x, minibatch_y = pm.Minibatch(X, y, batch_size=200)

with pm.Model() as model:
    sigma = pm.HalfCauchy("sigma", beta=3)
    weights = pm.Normal("W", 0, sigma=sigma, shape=nfeats)
    bias = pm.Normal("b", 0, 1)
    y_pred = (weights*minibatch_x).sum(1) + bias
    likelihood = pm.Normal("y", 
                        mu=y_pred, 
                        sigma=sigma, 
                        observed=minibatch_y, 
                        total_size=len(y)
                        )
    
    mean_field = pm.fit(20_000)
mean_field_samples = mean_field.sample(draws=5000)


plt.figure(figsize=(10, 5))
seaborn.boxplot(
    np.concatenate(mean_field_samples['posterior']['W'])[:, :30]
)
plt.xlabel('Categorical variable index')
plt.ylabel('Posterior distribution of learned weights')

# Plot the actual coefficients:
plt.plot(coeffs[:30], '-o', label='Actual coefficients')
plt.legend()

The outputs are the two boxplots described above: with batch_size=20, many of the first-30 weight posteriors sit on the prior, while with batch_size=200 the posteriors recover the true coefficients.
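As a more direct check of the rotation itself, here's a sketch (again my own addition): tag each row with its index, compile a function that draws a minibatch, and count how many distinct rows show up. For scale: with 1000 rows and batch_size=20, a thousand draws should touch essentially every row if the slice indices rotate; if they don't, we'd only ever see 20 distinct rows. This assumes the minibatch's RNG advances via PyTensor's default updates on each call.

import pytensor

# Rotation check (my addition): tag each row with its index so we can
# see exactly which rows each minibatch draw selects.
row_ids = np.arange(len(X))[:, None]
X_tagged = np.hstack([row_ids, X])

mb_x, mb_y = pm.Minibatch(X_tagged, y, batch_size=20)
draw = pytensor.function([], mb_x)  # default updates should advance the RNG

seen = set()
for _ in range(1000):
    seen.update(draw()[:, 0].astype(int))
print(f"distinct rows seen in 1000 draws: {len(seen)} / {len(X)}")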

Finally: something I should have noticed earlier. Every time I call pm.fit I get a warning:

/Users/ljmartin/miniforge3/envs/compchem/lib/python3.11/site-packages/pymc/pytensorf.py:845:
UserWarning: RNG Variable RandomGeneratorSharedVariable(<Generator(PCG64) at 0x31A16D9A0>) has multiple clients. 
This is likely an inconsistent random graph.
  warnings.warn(
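For what it's worth, here's a rough way to poke at the graph and count the shared RNGs behind the two minibatch tensors (a diagnostic sketch on my part, assuming pytensor's ancestors helper walks the whole graph; the two tensors are supposed to share one RNG so they slice the same rows, which I'm guessing is related to the "multiple clients" message, but I'm not sure):

from pytensor.graph.basic import ancestors
from pytensor.compile.sharedvalue import SharedVariable

# Diagnostic sketch (my addition): list the shared RNG variables that
# appear as ancestors of the two minibatch tensors.
rngs = [a for a in ancestors([minibatch_x, minibatch_y])
        if isinstance(a, SharedVariable) and "Generator" in str(a.type)]
print(f"shared RNGs behind the minibatch pair: {len(rngs)}")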

Any help interpreting that would be fantastic.

Thank you!