I’m trying to understand why minibatching isn’t working for my problem. I have some clicks and transactions for each entity, and I’m using a Binomial likelihood to model the conversion rate (transactions / clicks). When I fit without minibatching (full-batch ADVI, or NUTS sampling), I get results pretty close to the generated conversion rates, but when I use minibatching the estimate for every entity ends up right around the mean conversion rate.
One thing that I thought might solve the problem, even if I didn’t think it would speed up the fitting, was using a batch size of 1. For reasons that I don’t understand, that was actually significantly worse: the model didn’t appear to learn anything from the data and just stayed at the prior mean of 0. As it trained, the loss only seemed to go up. Obviously there is something that I’m completely misunderstanding about minibatching.
This is how I’m generating the data:
import numpy as np
import pandas as pd

def sigmoid(x):
    return 1 / (1 + np.exp(-x))

n_ent = 1000
entity_ids = np.arange(n_ent)
# True conversion rates, centred around sigmoid(-4) (about 1.8%)
cvrs = sigmoid(np.random.normal(0, 1, size=n_ent) - 4)
clicks = np.round(np.random.lognormal(7, 1, size=n_ent)).astype(int)
transactions = np.random.binomial(clicks, cvrs)
X = pd.DataFrame({"entity_id": entity_ids, "cvrs": cvrs, "clicks": clicks, "transactions": transactions})
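As a quick sanity check on the simulated data, the raw ratio transactions / clicks should already track the generating cvrs fairly closely, since most entities have hundreds of clicks or more. The sketch below regenerates the data with a fixed seed and compares the two. The seed and the `rng` generator are assumptions added for reproducibility and are not part of the code above.

```python
import numpy as np

def sigmoid(x):
    return 1 / (1 + np.exp(-x))

rng = np.random.default_rng(0)  # assumed seed, only for reproducibility
n_ent = 1000
cvrs = sigmoid(rng.normal(0, 1, size=n_ent) - 4)
clicks = np.round(rng.lognormal(7, 1, size=n_ent)).astype(int)
transactions = rng.binomial(clicks, cvrs)

# Per-entity empirical conversion rate (guarding against zero clicks)
empirical_cvr = transactions / np.maximum(clicks, 1)

# How closely the raw ratio follows the generating rates
corr = np.corrcoef(cvrs, empirical_cvr)[0, 1]
mean_abs_err = np.mean(np.abs(empirical_cvr - cvrs))
print(f"correlation: {corr:.3f}, mean abs error: {mean_abs_err:.4f}")
```

A fit that returns roughly the same rate for every entity would show a correlation near zero on this check, which gives a simple number to compare the two ADVI runs against.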
This is the model I’m using:
import pymc as pm

coords = {"entity_ids": entity_ids}
with pm.Model(coords=coords) as model:
    # Data containers, so they can be swapped out for minibatches later
    clicks = pm.Data("clicks", X.clicks.astype(np.int32))
    transactions = pm.Data("transactions", X.transactions.astype(np.int32))
    entity_id = pm.Data("entity_id", X.entity_id.astype(np.int32))
    # One logit-scale conversion-rate parameter per entity
    mu = pm.Normal("beta", mu=0, sigma=10, dims="entity_ids")
    cvr = pm.Deterministic("cvr", pm.math.invlogit(mu[entity_id]))
    y_hat = pm.Binomial("y_hat", p=cvr, n=clicks, observed=transactions)
This is the optimization code that works:
with model:
    advi = pm.ADVI()
    approx = advi.fit(
        100000,
        callbacks=[pm.variational.callbacks.CheckParametersConvergence(tolerance=0.01)],
    )
    trace = approx.sample(1000)
And this is the code that doesn’t:
with model:
    minibatch_data_clicks, minibatch_data_transactions, minibatch_data_entity_id = pm.Minibatch(
        X.clicks.astype(np.int32),
        X.transactions.astype(np.int32),
        X.entity_id.astype(np.int32),
        batch_size=128,
    )
    advi = pm.ADVI()
    approx = advi.fit(
        100000,
        more_replacements={
            model.clicks: minibatch_data_clicks,
            model.transactions: minibatch_data_transactions,
            model.entity_id: minibatch_data_entity_id,
        },
        callbacks=[pm.variational.callbacks.CheckParametersConvergence(tolerance=0.01)],
    )
    trace = approx.sample(1000)
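
For reference, a minibatch objective is usually built so that the batch log-likelihood, rescaled by N / batch_size, is an unbiased estimate of the full-data log-likelihood. In PyMC that rescaling is what the `total_size` argument on an observed variable is for, and the minibatch code above does not pass it. Whether that explains the behaviour here is not established. The NumPy sketch below illustrates only the rescaling itself, using made-up per-entity log-likelihood values: `per_entity_loglik` is a hypothetical stand-in, not output from the model, and the seed is an assumption.

```python
import numpy as np

rng = np.random.default_rng(0)  # assumed seed, only for reproducibility
n_ent, batch_size = 1000, 128

# Hypothetical per-entity log-likelihood terms (stand-in values only)
per_entity_loglik = -rng.gamma(2.0, 3.0, size=n_ent)
full_loglik = per_entity_loglik.sum()

unscaled, scaled = [], []
for _ in range(5000):
    # Draw one minibatch of entities without replacement
    idx = rng.choice(n_ent, size=batch_size, replace=False)
    batch_sum = per_entity_loglik[idx].sum()
    unscaled.append(batch_sum)
    # Rescale so the batch sum estimates the full-data sum
    scaled.append(batch_sum * n_ent / batch_size)

print(f"full-data log-likelihood: {full_loglik:.1f}")
print(f"mean unscaled batch sum:  {np.mean(unscaled):.1f}")
print(f"mean rescaled batch sum:  {np.mean(scaled):.1f}")
```

On average the unscaled batch sum is only batch_size / N of the full-data value, so an objective built from it weights the likelihood less, relative to the prior, than the full-batch fit does.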