Improving NUTS Runtime

Hi everyone,

I’ve been using PyMC3 with NUTS for quite a large hierarchical model. A few features of the model/inference:

  • there are approximately 4000 observations.
  • there are a large number of latent variables, approximately 8000.
  • the purpose of the model is to infer a small number of these latent variables (10). There are weak correlations in the posterior (correlation coefficients ~0.4) in this reduced set of variables.
  • on my local machine, profiling the model takes ~1.5 seconds for 1000 evaluations of model.logpt
  • using normal priors, with negative binomial output distributions.
  • the first few 1000 samples or so (during tuning) are very fast, and then it slows down significantly.

The performance I get with this model is that 12000 samples (4000 + 8000 samples, across four chains) take about 5 hours. I’m not sure if this runtime is to be expected when the model is this large, but it would be great to see if it can be reduced. The model itself seems to evaluate fairly fast, but the performance of NUTS is bad. I’ve attached plots of the sampler statistics, but I don’t know how to interpret them very well.

The typical advice I see is to re-parameterise the model, especially to avoid ‘funnels’ with scale hyperparameters. I don’t think that my model has these hyperparameters. I’m not sure how else to try and optimise the model, and because there are such a large number of parameters, I’m not sure how to investigate what is causing the long runtime (e.g., computing pairwise plots for all parameters would give a huge number of plots!).

I looked into trying the experimental jitter+adapt_full NUTS initialisation, but I found that performance for this was actually significantly worse (which I’m surprised by).

Model code: CM_Combined_Final Class, Github Link

One option would be to manually set the mass matrix myself using a posterior trace that I have. I don’t actually know how to do this, and whether this would give a big improvement.


The slowness is mainly from the number of leapfrog steps it is taking - you can see that from the tree sizes, which are all at 1024 (you probably get a fair amount of warnings related to that as well).

Usually the recommendation is reparameterization - the reason it is taking so many leapfrog steps is the small step size that tuning adapts to, and the reason for that is bad geometry. I think adapt_full works poorly here because of the large number of parameters (so the full mass matrix is not estimated well). I would suggest starting by looking at the pair plot to see which parameters are correlated (you might not want to use the ArviZ pairplot for that, as it will be very slow for a large number of parameters), and reparameterizing those. Otherwise, try adapting a low-rank mass matrix.
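To avoid plotting every pair, one option is to compute the posterior correlation matrix numerically and only plot the worst offenders. A minimal numpy sketch of this idea (the `draws` array here is synthetic; in practice you would stack your trace into an (n_draws, n_params) array):

```python
import numpy as np

# Synthetic stand-in for posterior draws, shape (n_draws, n_params);
# in practice build this by flattening and stacking your trace.
rng = np.random.default_rng(0)
draws = rng.standard_normal((2000, 6))
draws[:, 1] = 0.8 * draws[:, 0] + 0.6 * draws[:, 1]  # inject one correlated pair

# Posterior correlations, upper triangle only.
corr = np.corrcoef(draws, rowvar=False)
i, j = np.triu_indices_from(corr, k=1)
order = np.argsort(-np.abs(corr[i, j]))

# Most correlated parameter pairs -- candidates for reparameterization
# or for a targeted pair plot.
top_pairs = [(int(i[k]), int(j[k]), float(corr[i[k], j[k]])) for k in order[:3]]
print(top_pairs)
```

This narrows thousands of potential pair plots down to the handful of pairs actually worth inspecting.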


Thanks for the suggestions. I’m not sure what sort of re-parameterisations are typically suggested (I usually see people suggesting switching between centered and non-centered parameterisations).

I also typically set target_accept=0.9; I suppose this could be reduced? I don’t get any divergent transitions.

Centered to non-centered is one, otherwise a couple of things to try:

  • More informative priors (cut out the tail areas of some parameters)
  • Reduce the number of parameters (especially if you have a normalization like softmax in your model)
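For reference, the centered-to-non-centered switch just replaces a directly sampled variable with a rescaled standard normal. A sketch (the PyMC3 lines are illustrative and commented out since they need a model context; the numpy part checks that the transform is distributionally equivalent):

```python
import numpy as np

# Centered (prone to funnel geometry when sigma is small):
#   sigma = pm.HalfNormal('sigma', 1.0)
#   x = pm.Normal('x', 0.0, sigma, shape=n)
#
# Non-centered (the sampler sees an isotropic standard normal):
#   sigma = pm.HalfNormal('sigma', 1.0)
#   x_raw = pm.Normal('x_raw', 0.0, 1.0, shape=n)
#   x = pm.Deterministic('x', sigma * x_raw)

# The two parameterizations imply the same distribution for x:
rng = np.random.default_rng(0)
sigma = 0.1
x_raw = rng.standard_normal(50_000)  # ~ Normal(0, 1)
x = sigma * x_raw                    # implied x ~ Normal(0, sigma)
print(round(float(x.std()), 3))
```

The point of the non-centered form is that the sampled space stays well-conditioned regardless of the (possibly small) scale parameter.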

Do you get divergences at target_accept=0.8?

1 Like

Thanks again - I’ll take a look and see if I can modify the parameterisation at all.

The model has a random walk in it; I would expect that this gives high correlations?

Another question: is it possible to directly set the mass matrix, and/or estimate it from posterior samples?

Not necessarily - you can do x = pm.Normal('x') and x_random_walk = tt.cumsum(x)

Yes - for example Initializing NUTS with covariance matrix as mass matrix
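Concretely, the idea there is to estimate the posterior covariance from existing draws and hand it to NUTS as the mass matrix. A hedged sketch (the PyMC3 lines are commented out since they need a model in scope, and the 0.9/0.1 shrinkage weight is an arbitrary regularization choice, not something from the thread):

```python
import numpy as np

# Stand-in for posterior draws with shape (n_draws, n_params); in practice,
# stack the variables from your existing trace in the model's variable order.
rng = np.random.default_rng(0)
draws = rng.standard_normal((2000, 20))

# Covariance estimate, shrunk toward its diagonal so it stays positive
# definite even when n_draws is small relative to n_params.
cov = np.cov(draws, rowvar=False)
cov = 0.9 * cov + 0.1 * np.diag(np.diag(cov))

# Then (PyMC3, not run here):
#   from pymc3.step_methods.hmc.quadpotential import QuadPotentialFull
#   step = pm.NUTS(potential=QuadPotentialFull(cov))
#   trace = pm.sample(step=step)
```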

1 Like