Slow sampling in hierarchical logistic regression

Hi everyone,

I have been experiencing extremely slow sampling on a hierarchical logistic regression I’ve coded up, and I would really appreciate input from the experts: either where I’m going wrong and how to speed it up, or an explanation of why slow sampling should be expected.

I’m trying to build a multi-level logistic model, with:

  • 1 independent variable
  • c. 60k data points
  • 1 node at the top of the tree (level_1_count == 1 in the code below)
  • 27 nodes on the second level (level_2_count == 27 in the code below)
  • 86 nodes on the third level (level_3_count == 86 in the code below)

Here’s the model:

  import pymc3 as pm

  # Prior hyperparameters (a_sig and b_sig are also reused as the between-node scales below)
  a_mu  = 0
  a_sig = 5
  b_mu  = -0.5
  b_sig = 0.1

  # level_*_count, the 1-based link arrays (level_2_link, level_3_link, level_3)
  # and df2 come from the confidential dataset and are not shown here.
  with pm.Model() as hierarchical_model:
      # Tree levels (non-centered: child = parent + offset * scale)
      a_1        = pm.Normal('a_1', mu=a_mu, sigma=a_sig, shape=level_1_count)
      b_1        = pm.Normal('b_1', mu=b_mu, sigma=b_sig, shape=level_1_count)
      a_2_offset = pm.Normal('a_2_offset', mu=0, sigma=1, shape=level_2_count)
      a_2        = pm.Deterministic('a_2', a_1[level_2_link - 1] + a_2_offset * a_sig)
      b_2_offset = pm.Normal('b_2_offset', mu=0, sigma=1, shape=level_2_count)
      b_2        = pm.Deterministic('b_2', b_1[level_2_link - 1] + b_2_offset * b_sig)
      a_3_offset = pm.Normal('a_3_offset', mu=0, sigma=1, shape=level_3_count)
      a_3        = pm.Deterministic('a_3', a_2[level_3_link - 1] + a_3_offset * a_sig)
      b_3_offset = pm.Normal('b_3_offset', mu=0, sigma=1, shape=level_3_count)
      b_3        = pm.Deterministic('b_3', b_2[level_3_link - 1] + b_3_offset * b_sig)

      # Linear predictor per row (via its leaf node) and sigmoid
      d = a_3[level_3 - 1] + (b_3[level_3 - 1] * df2['X'].values)
      sig = pm.Deterministic('sig', pm.math.sigmoid(d))

      # Likelihood
      Y_like = pm.Bernoulli('Y_like', p=sig, observed=df2['Y'].values)

      # Sample
      trace = pm.sample(draws=1000, tune=2000, chains=2, target_accept=0.9,
                        return_inferencedata=False, random_seed=42)
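Since the dataset can’t be shared, one way to experiment (and to time the model on data of the same size) is to simulate from the same generative process. Below is a minimal NumPy sketch of that, mirroring the non-centered parameterization in the model code. The random parent assignment for the 86 level-3 nodes and the standard-normal X are assumptions for illustration, not the real tree or data.

```python
import numpy as np

# Tree sizes and hyperparameters from the post; everything else is hypothetical.
level_1_count, level_2_count, level_3_count = 1, 27, 86
n_obs = 60_000
a_mu, a_sig, b_mu, b_sig = 0.0, 5.0, -0.5, 0.1

rng = np.random.default_rng(42)

# 1-based parent links, as the "- 1" in the model code implies.
level_2_link = np.ones(level_2_count, dtype=int)  # every level-2 node hangs off the single root
# Assumption: each level-2 node gets at least one child; the rest are assigned at random.
level_3_link = np.concatenate([
    np.arange(1, level_2_count + 1),
    rng.integers(1, level_2_count + 1, size=level_3_count - level_2_count),
])
level_3 = rng.integers(1, level_3_count + 1, size=n_obs)  # leaf node of each row (1-based)

# Non-centered draws: child = parent + standard-normal offset * scale.
a_1 = rng.normal(a_mu, a_sig, size=level_1_count)
b_1 = rng.normal(b_mu, b_sig, size=level_1_count)
a_2 = a_1[level_2_link - 1] + rng.normal(size=level_2_count) * a_sig
b_2 = b_1[level_2_link - 1] + rng.normal(size=level_2_count) * b_sig
a_3 = a_2[level_3_link - 1] + rng.normal(size=level_3_count) * a_sig
b_3 = b_2[level_3_link - 1] + rng.normal(size=level_3_count) * b_sig

# Simulated feature and Bernoulli outcomes.
x = rng.normal(size=n_obs)
d = a_3[level_3 - 1] + b_3[level_3 - 1] * x
p = 0.5 * (1.0 + np.tanh(0.5 * d))  # sigmoid(d), written to avoid overflow
y = rng.binomial(1, p)
```

Fitting the model to data like this is also a quick check that the posterior recovers the simulated a_3 and b_3.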

The above code took over 12 hours to run on my Windows laptop (i7 processor, 16GB RAM)! As far as I can tell, that is an unusually low sampling rate. The two-level equivalent (i.e. with the middle layer removed) takes about 3 hours, which is still too slow to be practical.
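For a sense of scale, here is the throughput arithmetic for the 12-hour run. It assumes the two chains ran in parallel (so each chain did tune + draws iterations in the full wall-clock time), and the tree depth of 7 is a purely hypothetical figure: NUTS takes up to 2**depth leapfrog steps per iteration, each needing one gradient of the log-likelihood over all rows.

```python
# Observed settings from the post.
draws, tune, hours = 1000, 2000, 12

iterations_per_chain = draws + tune
seconds = hours * 3600
sec_per_iteration = seconds / iterations_per_chain
print(f"{sec_per_iteration:.1f} s per NUTS iteration")  # 14.4

# Hypothetical: tree depth 7 means up to 2**7 leapfrog steps,
# i.e. up to 128 full-data gradient evaluations per iteration.
leapfrog_steps = 2 ** 7
print(f"{sec_per_iteration / leapfrog_steps * 1000:.1f} ms per gradient")  # 112.5
```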

The model runs without divergences and learns sensible posterior distributions. Initially I did get some divergences due to a misspecification error, and while fixing that I also learned about non-centered parameterizations and about raising target_accept and tune above their defaults. I kept them in after fixing the specification, because I gather they’re a good idea for models with complex posteriors like this one. Removing them doesn’t make sampling any faster.

For other models on my laptop, sampling with the same settings can be fast:

  • Non-hierarchical logistic regression for one node at the bottom of the above tree (c. 30s per node). NB: if I fit all 86 nodes simultaneously (but in a non-hierarchical model), it takes about 2 hours.
  • Hierarchical linear regressions, e.g. from some of the tutorials on the PyMC3 website: I get better sampling speeds than those listed on the pages.

I’m therefore tempted to conclude that the poor performance isn’t a technical issue with my laptop. Nor do data quality/volume, or the fit of the model to the underlying data, seem to be the issue: I replicated an example 2-level logistic regression presented by Nicole Carlson here (42:20 onwards). Her notebook (or at least some version of it) is here. This also samples slowly for me (c. 2-3 hours to run), despite the data being synthetically generated as if the logistic model were ‘true’. Note that in the video, I believe, she does not actually get to sampling, so I can’t compare our relative speeds.

I’ve also tried different prior settings, to no avail. The input feature is currently unstandardized; I tried standardizing it too, but that didn’t help either.
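If the feature is standardized, the fitted intercept and slope can be mapped back to the raw scale with a simple identity. A small sketch (the function names are hypothetical); note also that priors such as b_mu = -0.5, b_sig = 0.1 describe the slope on the raw scale, so they would need rescaling (b_std = b_raw * s) to mean the same thing after standardizing.

```python
import numpy as np

def standardize(x):
    """Return the standardized feature plus the mean and sd used."""
    m, s = x.mean(), x.std()
    return (x - m) / s, m, s

def unstandardize(a_std, b_std, m, s):
    """Map an intercept/slope fitted on (x - m) / s back to the raw-x scale.

    a_std + b_std * (x - m) / s  ==  (a_std - b_std * m / s) + (b_std / s) * x
    """
    return a_std - b_std * m / s, b_std / s

# Check on a hypothetical feature: the linear predictor is identical on both scales.
rng = np.random.default_rng(0)
x = rng.normal(loc=50.0, scale=12.0, size=1_000)
x_std, m, s = standardize(x)
a_std, b_std = 0.3, -1.2
a_raw, b_raw = unstandardize(a_std, b_std, m, s)
print(np.allclose(a_std + b_std * x_std, a_raw + b_raw * x))  # True
```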

My sense is that this model is just fundamentally difficult to sample from, with the difficulty increasing with each level added to the tree, but I don’t have a good theoretical understanding of why that would be the case (and it might not be!).
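On the theoretical side, one useful way to think about cost: every NUTS leapfrog step needs the gradient of the log-density, and for this model that gradient requires a pass over all N rows, scattered back into the 86 leaf parameters. Here is a NumPy sketch of that leaf-level computation (a hypothetical helper, not PyMC3’s internals), with a finite-difference check on toy data.

```python
import numpy as np

def loglik_and_grad(a_3, b_3, level_3, x, y):
    """Bernoulli log-likelihood of the leaf-level model and its gradient w.r.t. a_3 and b_3.

    level_3 holds 1-based leaf indices, as in the model code in the post.
    """
    idx = level_3 - 1
    d = a_3[idx] + b_3[idx] * x                    # one pass over all N rows
    ll = np.sum(y * d - np.logaddexp(0.0, d))      # stable log Bernoulli(sigmoid(d))
    resid = y - 0.5 * (1.0 + np.tanh(0.5 * d))     # d(ll)/d(d) = y - sigmoid(d)
    k = a_3.size
    grad_a = np.bincount(idx, weights=resid, minlength=k)      # scatter-add rows into leaves
    grad_b = np.bincount(idx, weights=resid * x, minlength=k)
    return ll, grad_a, grad_b

# Finite-difference check on a small hypothetical dataset with 5 leaves.
rng = np.random.default_rng(1)
level_3 = rng.integers(1, 6, size=2_000)
x = rng.normal(size=2_000)
y = rng.binomial(1, 0.4, size=2_000)
a, b = rng.normal(size=5), rng.normal(size=5)
ll, grad_a, grad_b = loglik_and_grad(a, b, level_3, x, y)
h = np.zeros(5)
h[2] = 1e-6
fd = (loglik_and_grad(a + h, b, level_3, x, y)[0]
      - loglik_and_grad(a - h, b, level_3, x, y)[0]) / 2e-6
print(fd, grad_a[2])  # should agree to several decimal places
```

The per-gradient cost here is dominated by N rather than by the number of nodes, so a slowdown from adding a level would more plausibly come from NUTS needing more leapfrog steps per draw; the tree-depth sampler statistics in the trace are one place to check that.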

I’d love some input on whether the slow sampling is due to my code being off, the nature of the model, or something else. Many thanks in advance!

NB: Unfortunately, I can’t share the dataset since it’s confidential.

@cocodimama your sampling speeds may not be unusual at all given the number of data points (60K) and the number of levels (you have 1x27x86 = 2,322). On average you have ~26 data points for every level. How many 1s are there in the data relative to 0s? It is very likely that the distribution across all these levels is not balanced, and that can increase sampling time.
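One quick way to answer the balance question: in the model code every row is indexed into one of the 86 level-3 nodes (a_3[level_3 - 1]), so np.bincount over the 1-based level_3 array gives rows per leaf and the share of 1s per leaf. A sketch with a toy example (group_summary is a hypothetical helper). Leaves with a share of exactly 0 or 1 are worth flagging, since their intercepts are only weakly pinned down by the data and lean heavily on the priors.

```python
import numpy as np

def group_summary(level_3, y, n_groups):
    """Rows and share of y == 1 per leaf node (1-based node ids, as in the model code)."""
    idx = np.asarray(level_3) - 1
    counts = np.bincount(idx, minlength=n_groups)
    ones = np.bincount(idx, weights=y, minlength=n_groups)
    with np.errstate(invalid="ignore", divide="ignore"):
        rate = ones / counts  # nan for empty groups
    return counts, rate

# Toy example with 3 leaf nodes.
level_3 = np.array([1, 1, 2, 2, 2, 3])
y = np.array([0, 1, 1, 1, 1, 0])
counts, rate = group_summary(level_3, y, n_groups=3)
print(counts.tolist())  # [2, 3, 1]
print(rate.tolist())    # [0.5, 1.0, 0.0]
```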

Have you tried fitting the same specification to a smaller data set of, say, 5K data points and observing the time taken? I have run hierarchical logistic regression models with 64GB of RAM on 3K levels and 120K records, and without the right priors they took more than 24 hours to complete. In your case, having only 1 predictor (IV) could be increasing the sampling time, as a single predictor may not easily explain the variance across all the levels.
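A minimal sketch of drawing a smaller subset for such a timing test, assuming it is preferable to keep every leaf node represented (stratified_subsample is a hypothetical helper; with the post’s data, frac would be about 5_000 / 60_000).

```python
import numpy as np

def stratified_subsample(level_3, frac, seed=0):
    """Row indices keeping about `frac` of the rows of every leaf node (at least one each)."""
    rng = np.random.default_rng(seed)
    level_3 = np.asarray(level_3)
    keep = []
    for node in np.unique(level_3):
        rows = np.flatnonzero(level_3 == node)
        n_keep = max(1, int(round(frac * rows.size)))
        keep.append(rng.choice(rows, size=n_keep, replace=False))
    return np.sort(np.concatenate(keep))

# Toy example: 10 leaves with 100 rows each, keep 10%.
level_3 = np.repeat(np.arange(1, 11), 100)
idx = stratified_subsample(level_3, frac=0.1)
print(idx.size)  # 100
# With a DataFrame one could then use df2.iloc[idx] and rebuild the model inputs from it.
```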

What do your priors look like? Are you using non-informative or informative priors? For a problem like this, informative priors may also help. Have you attempted to run the models in batch mode, or in sequential mode, where the posteriors of the first model become the priors of the next? You could, for example, divide your data into 6 or 12 smaller data sets and try the sequential or batch modeling approach. PyMC3 offers mini-batch ADVI (a variational inference method rather than an MCMC sampler), which may be a good fit for your problem.
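The sequential idea is easiest to check in a conjugate model, where using each chunk’s posterior as the next prior reproduces the full-data posterior exactly. A sketch with a Normal mean and known noise (a toy model, not the logistic regression), splitting 60,000 hypothetical observations into 12 chunks.

```python
import numpy as np

def normal_update(prior_mu, prior_sd, data, noise_sd):
    """Conjugate posterior (mean, sd) for a Normal mean with known noise sd."""
    prec = 1.0 / prior_sd**2 + data.size / noise_sd**2
    mu = (prior_mu / prior_sd**2 + data.sum() / noise_sd**2) / prec
    return mu, np.sqrt(1.0 / prec)

def moment_match(samples):
    """Summarize posterior draws (draws x params) as independent Normal priors for the next chunk."""
    return samples.mean(axis=0), samples.std(axis=0, ddof=1)

rng = np.random.default_rng(0)
data = rng.normal(loc=1.3, scale=2.0, size=60_000)

# One fit on all the data...
full = normal_update(0.0, 5.0, data, noise_sd=2.0)

# ...versus 12 sequential fits, each using the previous posterior as its prior.
mu, sd = 0.0, 5.0
for chunk in np.array_split(data, 12):
    mu, sd = normal_update(mu, sd, chunk, noise_sd=2.0)

print(np.allclose((mu, sd), full))  # True
```

For the logistic model there is no conjugate update, so each posterior would have to be approximated, e.g. by moment-matching the draws to independent Normals as in moment_match; that drops posterior correlations, so the sequential result is only approximate.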

Here is an example: GLM: Mini-batch ADVI on a Hierarchical Regression Model. Also see the PyMC3 API documentation.
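To illustrate what mini-batch ADVI does, here is a from-scratch NumPy sketch for a single-level logistic regression: a mean-field Gaussian approximation fitted by stochastic gradient ascent on the ELBO, where each step uses one minibatch whose log-likelihood gradient is scaled by N / batch_size (the role played by total_size in PyMC3). It is not PyMC3’s implementation: the RMSprop-style step, the step-size schedule, the N(0, 5) priors and the simulated a = 0.5, b = -1.0 are all assumptions. In PyMC3 itself the corresponding pieces are pm.Minibatch, the total_size argument and pm.fit(method='advi').

```python
import numpy as np

def sigmoid(z):
    # tanh form is numerically stable for large |z|
    return 0.5 * (1.0 + np.tanh(0.5 * z))

def minibatch_advi(x, y, prior_mu, prior_sd, batch_size=500, n_iter=5000, lr=0.05, seed=0):
    """Mean-field Gaussian ADVI for y ~ Bernoulli(sigmoid(a + b * x)) with Normal priors.

    Each step draws a minibatch and scales its log-likelihood gradient by
    N / batch_size, an unbiased estimate of the full-data gradient.
    Returns the variational means and standard deviations of (a, b).
    """
    rng = np.random.default_rng(seed)
    n = x.size
    scale = n / batch_size
    mu = np.zeros(2)     # variational means of (a, b)
    omega = np.zeros(2)  # log of the variational standard deviations
    sq = np.zeros(4)     # running mean of squared gradients (RMSprop-style)
    rho = 0.9
    for t in range(1, n_iter + 1):
        batch = rng.integers(0, n, size=batch_size)
        xb, yb = x[batch], y[batch]
        eps = rng.normal(size=2)
        theta = mu + np.exp(omega) * eps                # reparameterized draw from q
        resid = yb - sigmoid(theta[0] + theta[1] * xb)
        grad_theta = scale * np.array([resid.sum(), (resid * xb).sum()])
        grad_theta -= (theta - prior_mu) / prior_sd**2  # Normal log-prior gradient
        # ELBO gradient w.r.t. mu, then w.r.t. omega (the +1 is the entropy term)
        g = np.concatenate([grad_theta, grad_theta * eps * np.exp(omega) + 1.0])
        sq = rho * sq + (1 - rho) * g**2
        step = lr / np.sqrt(1.0 + t / 100.0)           # slowly decaying step size
        update = step * g / (np.sqrt(sq / (1 - rho**t)) + 1e-8)
        mu += update[:2]                               # gradient ascent on the ELBO
        omega += update[2:]
    return mu, np.exp(omega)

# Usage on simulated data (hypothetical true a = 0.5, b = -1.0).
rng = np.random.default_rng(42)
x = rng.normal(size=20_000)
y = rng.binomial(1, sigmoid(0.5 - 1.0 * x))
mean, sd = minibatch_advi(x, y, prior_mu=np.zeros(2), prior_sd=np.full(2, 5.0))
print(mean, sd)
```

Each step touches only batch_size rows instead of all N, which is where the speed-up over full-data NUTS comes from; the price is an approximate (mean-field) posterior.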

With a data set this size you could also consider using PyMC3 with JAX. JAX compiles the model’s computations (via XLA), which can speed up sampling on large datasets. You can also consider BlackJAX; see another BlackJAX example. If you have access to a GPU, JAX/BlackJAX can speed up your models further.

@sree_datta Thank you for your detailed suggestions. I confirmed that the slow speed was being caused by the data set size and I have implemented mini-batch ADVI. This runs much faster whilst still giving very similar estimates of the posterior means in question (c. 98% accurate), which is more than sufficient for my purposes. Many thanks once again.


@cocodimama good to see both of the confirmations about the impact of data set size on speed and the improved efficiency with mini-batch ADVI.