I am fitting a Bayesian logistic regression model with 800,000 observations, and I want to use variational methods to estimate the posterior of the parameters. I am running this on an online platform, and it is very slow: 10,000 iterations of variational Bayes take more than 100 hours, and the module dies mid-run before it finishes. Is there a way to add checkpoints during the iterations and resume from the last completed one? Or is there a way to make the variational fit faster on an online platform?
import pymc as pm

# X_train (DataFrame with feature_1, feature_2) and y_train (0/1 labels) are defined elsewhere.
with pm.Model() as logistic_model:
    # Weakly informative priors on the intercept and the two slopes
    beta_0 = pm.Normal('beta_0', 0, 4)
    beta_1 = pm.Normal('beta_1', 0, 4)
    beta_2 = pm.Normal('beta_2', 0, 4)
    feature_1 = pm.Data("feature_1", value=X_train['feature_1'], mutable=True)
    feature_2 = pm.Data("feature_2", value=X_train['feature_2'], mutable=True)
    label = pm.Data("label", value=y_train, mutable=True)
    observed = pm.Bernoulli("binary_label", pm.math.sigmoid(beta_0 + beta_1 * feature_1 + beta_2 * feature_2), observed=label)

with logistic_model:
    mean_field = pm.fit(n=10000, method='advi')