PyMC3 slows rapidly with increasing numbers of parameters

I haven’t looked at the model in detail yet, but I can try to give you a couple of general hints.
How quickly NUTS samples depends less on the number of parameters than on the shape of the posterior. If the posterior is reasonably close to a normal without correlations, then 10000 parameters aren’t a problem. But if your posterior has certain features, things get much more difficult really quickly. Comparing sampling speed between NUTS and Metropolis is usually pretty pointless: Metropolis just doesn’t notice that its samples are bad and will happily give you tons of useless samples.
Things that make sampling hard, in no particular order:

  • Very different posterior variances in different variables. Most of the time we get around that using advi (we only use the result to rescale the variables), but in some cases advi gives stupid answers and then NUTS gets into trouble. You can manually call pm.fit to have a look at what results you get and check whether they roughly match the trace. (If you feel very adventurous you could also use the work-in-progress branch https://github.com/pymc-devs/pymc3/pull/2327, which adapts the scaling after advi. Update: this is no longer work in progress; since 3.2 it is the default.)
  • Correlated variables. In low dimensions this usually isn’t that bad, but it can still slow down sampling. In cases like that you could try to find a parametrization such that the correlations in the posterior are smaller, or, if n is small, use full-rank advi for initialization (via the init param of pm.sample; init='nuts' also sets a full mass matrix to get rid of correlations).
  • “Funnels”. These often happen with scale parameters in hierarchical models: the posterior of the other variables given a small value of the scale parameter might be very different than the posterior of the other variables given a large value. The scaling adaptation can’t work in cases like that, because different regions of the posterior would need different adaptation values. Reparametrization can usually help here as well; typically you end up switching between “centered” and “non-centered” parametrizations. See for example https://twiecki.github.io/blog/2017/02/08/bayesian-hierchical-non-centered/
  • Very long tailed posteriors. Sometimes it can help in those cases to set a more informative prior.
  • Unidentifiable (or nearly unidentifiable) parameters. In that case you need a better model.
  • Multimodality. This only slows down the sampler if you are lucky. In most cases the sampler just ends up near one mode and you might never find out (always run several chains!!)
  • Model misspecification. This is probably the most common problem. If the logp is just wrong, NUTS often gets stuck. This is a special case of the folk theorem: http://andrewgelman.com/2008/05/13/the_folk_theore/

Short story: there’s probably a problem with your model. Try plotting the variables from the trace, and also look at scatter plots between pairs of variables.
