Hi everyone,
I have been experiencing extremely slow sampling on a hierarchical logistic regression I’ve coded up, and I would really appreciate input from the experts! Either on where I’m going wrong and how to speed it up, or an explanation of why slow sampling would be expected here.
I’m trying to build a multi-level logistic model, with:
- 1 independent variable
- c. 60k data points
- 1 node at the top of the tree (`level_1_count == 1` in the code below)
- 27 nodes on the second level (`level_2_count == 27` in the code below)
- 86 nodes on the third level (`level_3_count == 86` in the code below)
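The link arrays the model indexes with (`level_2_link`, `level_3_link`) aren’t shown in the post; I’m assuming each entry maps a child node to the 1-based index of its parent on the level above, which is what the `a_2[level_3_link - 1]` pattern below implies. A toy NumPy sketch of that mapping (the specific parent assignments here are made up for illustration):

```python
import numpy as np

level_1_count = 1
level_2_count = 27
level_3_count = 86

# Every level-2 node hangs off the single root node (1-based parent index).
level_2_link = np.ones(level_2_count, dtype=int)
# Hypothetical assignment: spread the 86 level-3 nodes across the 27 level-2 nodes.
level_3_link = (np.arange(level_3_count) % level_2_count) + 1

# Fancy indexing broadcasts each parent's value down to its children,
# exactly as in the model's `a_2[level_3_link - 1]` expression.
a_2 = np.random.default_rng(0).normal(size=level_2_count)
a_2_per_child = a_2[level_3_link - 1]
print(a_2_per_child.shape)  # (86,): one parent value per level-3 node
```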
Here’s the model:
```python
a_mu = 0
a_sig = 5
b_mu = -0.5
b_sig = 0.1

with hierarchical_model:
    # Tree levels
    a_1 = pm.Normal('a_1', mu=a_mu, sigma=a_sig, shape=level_1_count)
    b_1 = pm.Normal('b_1', mu=b_mu, sigma=b_sig, shape=level_1_count)

    a_2_offset = pm.Normal('a_2_offset', mu=0, sigma=1, shape=level_2_count)
    a_2 = pm.Deterministic('a_2', a_1[level_2_link - 1] + a_2_offset * a_sig)
    b_2_offset = pm.Normal('b_2_offset', mu=0, sigma=1, shape=level_2_count)
    b_2 = pm.Deterministic('b_2', b_1[level_2_link - 1] + b_2_offset * b_sig)

    a_3_offset = pm.Normal('a_3_offset', mu=0, sigma=1, shape=level_3_count)
    a_3 = pm.Deterministic('a_3', a_2[level_3_link - 1] + a_3_offset * a_sig)
    b_3_offset = pm.Normal('b_3_offset', mu=0, sigma=1, shape=level_3_count)
    b_3 = pm.Deterministic('b_3', b_2[level_3_link - 1] + b_3_offset * b_sig)

    # Sigmoid function
    d = a_3[level_3 - 1] + (b_3[level_3 - 1] * df2['X'])
    sig = pm.Deterministic('sig', pm.math.sigmoid(d))

    # Likelihood
    Y_like = pm.Bernoulli('Y_like', p=sig, observed=df2['Y'])

    # Sample
    trace = pm.sample(draws=1000, tune=2000, chains=2, target_accept=0.9,
                      return_inferencedata=False, random_seed=42)
```
The above code, on my Windows laptop with an i7 processor and 16GB RAM, took over 12 hours to run! As far as I can tell, that’s an unbelievably low sampling rate. The two-level equivalent (i.e. with the middle layer removed) takes about 3 hours to run, which is still too slow to be practical.
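To put a rough number on that rate: with `tune=2000`, `draws=1000`, and 2 chains, the sampler performs 6,000 transitions in total, so 12 hours works out to well under one iteration per second (assuming the chains run sequentially, which is often the case on Windows):

```python
tune, draws, chains = 2000, 1000, 2
total_iterations = (tune + draws) * chains  # 6000 transitions overall
runtime_s = 12 * 3600                       # "over 12 hours"
rate = total_iterations / runtime_s
print(f"{rate:.3f} iterations/s")           # roughly 0.14 it/s
```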
The model runs without divergences and learns sensible posterior distributions. However, I initially did get some divergences due to a misspecification error, and in the course of fixing that I also learned about non-centered parameterizations and higher-than-default `target_accept` and `tune`. I kept those in even after fixing the specification, because I gather they’re a good idea for models with complex posteriors like this one. Please note that removing them doesn’t increase the sampling speed.
For other models on my laptop, sampling with the same settings can be fast:
- Non-hierarchical logistic regression for one node at the bottom of the above tree (c. 30s to run per node). NB, if I do all 86 nodes simultaneously (but in a non-hierarchical model), it takes about 2 hours again.
- Hierarchical linear regressions, e.g. from some of the tutorials on the pymc3 website - I get better sampling speeds than those listed on the pages
I’m therefore tempted to conclude that the poor performance isn’t a technical issue with my laptop. Data quality/volume, or the fit of the model to the underlying data, doesn’t seem to be the underlying issue either, since I replicated an example two-level logistic regression presented by Nicole Carlson here (42:20 onwards). Her notebook (or at least some version of it) is here. This also samples slowly for me (c. 2-3 hours to run), despite the data being synthetically generated as if the logistic model were ‘true’. Note that, I believe, she does not actually get to sampling in the video - so I can’t compare our relative speeds.
I’ve also tried different prior settings, to no avail. Note that the input feature is currently unstandardized - I tried standardizing it too, but that didn’t help either.
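For reference, by standardizing I mean the usual z-scoring of the input column before it enters the model - something like the following (the toy frame here just stands in for the confidential dataset; only the column name `X` comes from the code above):

```python
import pandas as pd

# Hypothetical stand-in for df2; the real data is confidential.
df2 = pd.DataFrame({'X': [1.0, 2.0, 3.0, 4.0, 5.0]})

# Z-score standardization: zero mean, unit standard deviation.
df2['X_std'] = (df2['X'] - df2['X'].mean()) / df2['X'].std()
print(df2['X_std'].mean(), df2['X_std'].std())  # ~0.0 and ~1.0
```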
My sense is that this model is just fundamentally difficult to sample from, with the difficulty compounding with each level added to the tree, but I don’t have a good theoretical understanding of why that’s the case - and it might not be!
I’d love some input on whether the slow sampling is due to my code being off, the nature of the model, or something else. Many thanks in advance!
NB: Unfortunately, I can’t share the dataset since it’s confidential.