PyMC3 slows rapidly with increasing numbers of parameters

I am trying to use PyMC3 to fit the spectra of galaxies. The model I use to fit the spectra is currently described by four parameters. At present, I am trying to fit simulated spectra (i.e., data) to assess (a) how reliably PyMC3 is able to constrain the known model parameters and (b) how quickly it converges.

All the parameters in my model are continuous, so I’m using the NUTS sampler. When I only fit a single parameter (i.e., fix the other three at the known truth values) the sampler runs quickly (advi stage at ~5000 it/s, sampling stage at ~1000 it/s). However, the rate of sampling falls dramatically as I add more free parameters. For example, when sampling all four parameters the advi stage still runs fairly quickly (at ~3000 it/s), but the sampling stage falls to ~20 it/s.

By contrast, if I use Metropolis to sample all four parameters, I get a sampling rate of ~1000 it/s.

The fact that it samples quickly with a single free parameter and Metropolis suggests to me that the low sampling rate is not due to each model evaluation taking a long time. Instead, I was wondering whether it could be related to NUTS’s gradient calculation.

I was just wondering whether this sounds like normal behaviour (i.e., slow sampling with just four model parameters) and if you have any advice on how I could speed things up (I’d like to use this to fit the spectra of ~tens of thousands of galaxies, so speed is fairly important).

Code available at:
https://github.com/SheffAGN/SEDist (see test.py).

Thanks for your help, and keep up the good work!

I haven’t looked at the model in detail yet, but I can give you a couple of general hints.
How quickly NUTS samples depends less on the number of parameters than on the shape of the posterior. If the posterior is reasonably close to a normal distribution without correlations, then 10000 parameters aren’t a problem. But if your posterior has certain features, things get much more difficult very quickly. Comparing sampling speed between NUTS and Metropolis is usually pretty pointless: Metropolis just doesn’t notice that its samples are bad and will happily give you tons of useless samples.
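To make the NUTS/Metropolis comparison more meaningful, one option is to compare effective samples per second rather than raw iterations per second. Below is a minimal NumPy sketch of an autocorrelation-based effective sample size (ESS) estimate, truncated with Geyer's initial positive sequence. It is a simplified estimator, not the exact one PyMC3 or ArviZ uses, and the AR(1) chain is a made-up stand-in for a sampler whose consecutive draws are highly correlated.

```python
import numpy as np


def effective_sample_size(chain):
    """Rough ESS of a 1-D chain: n / (1 + 2 * sum of autocorrelations).

    The autocorrelation sum is truncated with Geyer's initial positive
    sequence (stop at the first negative pair of lags).
    """
    x = np.asarray(chain, dtype=float)
    n = x.size
    x = x - x.mean()
    # Autocovariance via FFT, zero-padded to 2n so lags do not wrap around.
    f = np.fft.rfft(x, n=2 * n)
    acov = np.fft.irfft(f * np.conjugate(f), n=2 * n)[:n] / n
    rho = acov / acov[0]
    # tau = -1 + 2 * sum_k (rho[2k] + rho[2k+1]) = 1 + 2 * sum_{t>=1} rho[t]
    tau = -1.0
    for k in range(0, n - 1, 2):
        pair = rho[k] + rho[k + 1]
        if pair < 0:
            break
        tau += 2.0 * pair
    return n / tau


rng = np.random.default_rng(0)
n = 20_000

# Independent draws: ESS should be close to n.
iid_chain = rng.standard_normal(n)

# Hypothetical "sticky" chain: AR(1) with phi = 0.9.
# Theory: ESS is about n * (1 - phi) / (1 + phi).
phi = 0.9
sticky_chain = np.empty(n)
sticky_chain[0] = rng.standard_normal()
for t in range(1, n):
    sticky_chain[t] = phi * sticky_chain[t - 1] + rng.standard_normal()

print(effective_sample_size(iid_chain))     # roughly n
print(effective_sample_size(sticky_chain))  # on the order of 1000
```

Dividing the ESS by wall-clock time gives effective samples per second, which is the number worth comparing between samplers.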
Things that make sampling hard, in no particular order:

  • Very different posterior variances in different variables. Most of the time we get around that using ADVI (we only use its result to rescale the variables), but in some cases ADVI gives stupid answers and then NUTS gets into trouble. You can call pm.fit manually to look at the results you get and check whether they roughly match the trace. (If you feel very adventurous, you could also use the work-in-progress branch https://github.com/pymc-devs/pymc3/pull/2327, which adapts the scaling after ADVI. Update: this is no longer a work in progress; since 3.2 it is the default.)
  • Correlated variables. In low dimensions this usually isn’t that bad, but it can still slow down sampling. In cases like that, you could try to find a parametrization in which the correlations in the posterior are smaller or, if the number of parameters is small, use full-rank ADVI for initialization (via the init parameter of pm.sample; the 'nuts' init option also sets a full mass matrix to get rid of correlations).
  • “Funnels”. These often happen with scale parameters in hierarchical models. The posterior of the other variables given a small value of a scale parameter might be very different from their posterior given a large value. The scaling adaptation can’t work in cases like that, because different regions of the posterior would need different adaptation values. Reparametrization usually helps here too; typically you end up switching between “centered” and “non-centered” parametrizations. See for example https://twiecki.github.io/blog/2017/02/08/bayesian-hierchical-non-centered/
  • Very long tailed posteriors. Sometimes it can help in those cases to set a more informative prior.
  • Unidentifiable parameters (or nearly unidentifiable ones). You need a better model then.
  • Multimodality. This only slows down the sampler if you are lucky. In most cases the sampler just ends up near one mode and you might never find out (always run several chains!!)
  • Model misspecification. This is probably the most common problem. If the logp is just wrong, nuts often gets stuck. This is a special case of the folk theorem: http://andrewgelman.com/2008/05/13/the_folk_theore/
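The first point (very different posterior variances) can be illustrated with a made-up toy posterior: an independent Gaussian whose standard deviations differ by four orders of magnitude. The NumPy sketch below runs plain leapfrog integration, the integrator inside HMC/NUTS, with a diagonal inverse mass matrix. With an identity mass matrix the step size has to stay below roughly twice the smallest standard deviation, so a larger step blows up; setting the inverse mass matrix to the posterior variances (roughly what the scaling adaptation tries to estimate) makes every direction look like a unit normal. This is a sketch for intuition, not PyMC3's implementation.

```python
import numpy as np

# Toy posterior (an assumption for illustration): independent N(0, sigma_i^2).
sigma = np.array([0.01, 1.0, 100.0])


def grad_logp(q):
    # Gradient of log N(0, diag(sigma^2)) with respect to q.
    return -q / sigma**2


def hamiltonian(q, p, inv_mass):
    # Potential -logp(q) (up to a constant) plus kinetic energy p^T M^-1 p / 2.
    return 0.5 * np.sum(q**2 / sigma**2) + 0.5 * np.sum(inv_mass * p**2)


def leapfrog(q, p, eps, n_steps, inv_mass):
    """Leapfrog integration with a diagonal inverse mass matrix."""
    q, p = q.copy(), p.copy()
    p += 0.5 * eps * grad_logp(q)
    for i in range(n_steps):
        q += eps * inv_mass * p
        if i < n_steps - 1:
            p += eps * grad_logp(q)
    p += 0.5 * eps * grad_logp(q)
    return q, p


def energy_error(eps, n_steps, inv_mass):
    # Fixed start about one standard deviation from the mode in each direction, and a
    # fixed momentum on the scale implied by the mass matrix (real HMC draws p ~ N(0, M)).
    q0 = sigma * np.array([1.0, -0.5, 0.3])
    p0 = np.array([0.5, 1.0, -0.8]) / np.sqrt(inv_mass)
    q1, p1 = leapfrog(q0, p0, eps, n_steps, inv_mass)
    return abs(hamiltonian(q1, p1, inv_mass) - hamiltonian(q0, p0, inv_mass))


# Identity mass matrix: stable only for eps < 2 * sigma.min() = 0.02, so crossing the
# widest direction (sigma = 100) would take thousands of steps. eps = 0.05 diverges.
print(energy_error(eps=0.05, n_steps=20, inv_mass=np.ones(3)))

# Inverse mass matrix = posterior variances: every direction now has unit scale,
# so eps = 0.5 is stable and the energy error stays small.
print(energy_error(eps=0.5, n_steps=20, inv_mass=sigma**2))
```

A huge energy error means the proposal is rejected (PyMC3 reports such cases as divergences), while a tiny step size that avoids it forces very long trajectories. Either way, sampling slows down.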

Long story short: there’s probably a problem with your model. Try plotting the traces of each variable, and also look at scatter plots between the different variables.
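As a quick numerical companion to the scatter plots, the NumPy sketch below lists pairs of variables whose posterior draws are strongly correlated (the 0.8 threshold is an arbitrary choice). With a PyMC3 trace of scalar variables, `samples` could be built with something like `np.column_stack([trace[name] for name in names])`; the draws used here are synthetic.

```python
import numpy as np


def strongly_correlated_pairs(samples, names, threshold=0.8):
    """List variable pairs whose posterior correlation |r| exceeds `threshold`.

    samples: array of shape (n_draws, n_vars), one column per scalar variable.
    """
    corr = np.corrcoef(np.asarray(samples, dtype=float), rowvar=False)
    pairs = []
    for i in range(len(names)):
        for j in range(i + 1, len(names)):
            if abs(corr[i, j]) > threshold:
                pairs.append((names[i], names[j], float(corr[i, j])))
    return pairs


# Made-up draws standing in for a trace: "b" is nearly a linear function of "a".
rng = np.random.default_rng(1)
a = rng.standard_normal(2000)
b = 0.95 * a + 0.1 * rng.standard_normal(2000)
c = rng.standard_normal(2000)
draws = np.column_stack([a, b, c])

print(strongly_correlated_pairs(draws, ["a", "b", "c"]))
```

Pearson correlation only captures linear dependence, so the scatter plots are still worth looking at for curved ridges or funnels.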

Not that I think this is actually the problem, but if you want to scale this up, you’ll probably need it at some point: you can get information about how long a leapfrog step takes using something like this (no stability guarantees; we might change this in the near future):

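import numpy as np
import pymc3 as pm
from pymc3.theanof import floatX

# `model` is your existing pm.Model; this relies on PyMC3 internals of this era
N = model.ndim
with model:
    # unit scaling; compile a single leapfrog step so it can be timed on its own
    step = pm.step_methods.hmc.base_hmc.BaseHMC(
        scaling=np.ones(N), profile=False, use_single_leapfrog=True)

# random position q and momentum p in the model's unconstrained space
np.random.seed(42)
q = floatX(np.random.randn(N))
p = floatX(np.random.randn(N))
epsilon = floatX(np.array(0.01))
q_grad = model.dlogp_array(q)
# %timeit is IPython/Jupyter magic; run this in a notebook or IPython shell
%timeit _ = step.leapfrog(q, p, q_grad, epsilon)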

If you set profile=True, you can get a summary of what Theano is doing with step.leapfrog.profile.summary(). Alternatively, you can get a graphical representation with theano.printing.pydotprint(step.leapfrog) (which will probably be somewhat large…), or an interactive HTML version with d3viz:

import theano.d3viz as d3v
d3v.d3viz(step.leapfrog, 'profile.html')
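If the internal `BaseHMC` interface above is not available in your PyMC3 version (it comes with no stability guarantees), a similar measurement can be approximated outside PyMC3 by timing one leapfrog step built around any function that returns the gradient of the log density. In this NumPy sketch, `grad_logp` is a hypothetical stand-in (a standard normal); in practice you would plug in your model's gradient, e.g. the `model.dlogp_array` used above if your version provides it. Note that with the toy gradient this measures NumPy overhead, not Theano's compiled graph.

```python
import timeit

import numpy as np


def grad_logp(q):
    # Hypothetical stand-in: gradient of a standard normal log density.
    return -q


def leapfrog_step(q, p, grad, eps):
    """One leapfrog step (unit mass matrix); returns new q, p and the gradient at new q."""
    p_half = p + 0.5 * eps * grad
    q_new = q + eps * p_half
    grad_new = grad_logp(q_new)  # one gradient evaluation per step
    p_new = p_half + 0.5 * eps * grad_new
    return q_new, p_new, grad_new


N = 4  # number of free parameters, as in the original question
rng = np.random.default_rng(42)
q = rng.standard_normal(N)
p = rng.standard_normal(N)
eps = 0.01
grad = grad_logp(q)

n_calls = 10_000
total = timeit.timeit(lambda: leapfrog_step(q, p, grad, eps), number=n_calls)
seconds_per_step = total / n_calls
print(f"{seconds_per_step * 1e6:.2f} microseconds per leapfrog step")
```

Multiplying the per-step time by the number of leapfrog steps per draw gives a rough cost per draw. NUTS keeps doubling its trajectory until it detects a U-turn, so that number can vary a lot between posteriors.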

I had a quick look at your model. My suggestion (besides all the great general advice from @aseyboldt) is to change the Uniform prior (https://github.com/SheffAGN/SEDist/blob/master/test.py#L43-L46) to a weakly informative prior. You can see some recommendations from Stan here. In general, a weakly informative prior works much better than a bounded flat prior (i.e., Uniform).
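One concrete difference between the two choices: PyMC3 samples a Uniform(lower, upper) variable on an unbounded log-odds scale, and there the flat prior plus the Jacobian of the transform gives a log density that falls off only linearly, while a Normal prior on an unbounded parameter falls off quadratically. When the data say little about a parameter (for instance in the partly unidentifiable regions mentioned later in this thread), the prior's tails are what the sampler ends up exploring. The NumPy sketch below just evaluates both log densities, up to constants, on the scale the sampler sees; the standard normal is an arbitrary example of a weakly informative prior, not a recommendation for this particular model.

```python
import numpy as np


def uniform_logodds_logp(y):
    """Log density (up to a constant) of a Uniform(lower, upper) prior expressed on
    the log-odds scale y = log((x - lower) / (upper - x)).

    The flat prior only contributes a constant; the Jacobian of
    x = lower + (upper - lower) * sigmoid(y) contributes
    log sigmoid(y) + log(1 - sigmoid(y)).
    """
    y = np.asarray(y, dtype=float)
    # log sigmoid(y) = -log(1 + exp(-y)) and log(1 - sigmoid(y)) = -log(1 + exp(y)),
    # both computed stably with logaddexp.
    return -np.logaddexp(0.0, -y) - np.logaddexp(0.0, y)


def normal_logp(x, mu=0.0, sd=1.0):
    """Log density (up to a constant) of a Normal(mu, sd) prior on an unbounded parameter."""
    x = np.asarray(x, dtype=float)
    return -0.5 * ((x - mu) / sd) ** 2 - np.log(sd)


grid = np.array([0.0, 5.0, 10.0, 20.0])
print(uniform_logodds_logp(grid))  # decays roughly like -|y|
print(normal_logp(grid))           # decays like -x**2 / 2
```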

Also, another thing to keep in mind when reparameterizing is to put all your parameters on roughly the same scale (ideally with zero mean as well). This helps both ADVI and NUTS reach the typical set. For example, in your model plnorm is on a much smaller scale than tp and temp, which might result in a slowdown during tuning (although after tuning it should be fine).
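A minimal sketch of that rescaling: let the sampler work with standardized coordinates z of order one and map them to the physical parameters with a fixed affine transform. The centres and widths below are made-up placeholders, not values from test.py; in PyMC3 the same idea is commonly written as a unit-Normal variable plus `pm.Deterministic` for the physical quantity. By the chain rule, the gradient with respect to z is the physical gradient times the scale, so if the chosen widths roughly match the posterior widths, parameters with very different natural units end up with comparable gradient magnitudes.

```python
import numpy as np

# Hypothetical prior centres and widths, one per parameter (NOT taken from test.py).
NAMES = ["plnorm", "tp", "temp"]
LOC = np.array([1e-3, 2.0, 50.0])
SCALE = np.array([5e-4, 1.0, 20.0])


def to_physical(z):
    """Map standardized coordinates z (about N(0, 1)) to physical parameters."""
    return LOC + SCALE * np.asarray(z, dtype=float)


def to_standard(theta):
    """Inverse map: physical parameters to standardized coordinates."""
    return (np.asarray(theta, dtype=float) - LOC) / SCALE


def grad_z(grad_theta):
    """Chain rule: d logp / dz = SCALE * d logp / dtheta (the transform is affine)."""
    return SCALE * np.asarray(grad_theta, dtype=float)


# Toy check: if each parameter's posterior were N(LOC, SCALE**2), the physical gradient
# at LOC + SCALE spans several orders of magnitude, but the z-gradient is -1 in every
# component (up to rounding).
theta = to_physical(np.ones(3))
grad_theta = -(theta - LOC) / SCALE**2
print(grad_theta)
print(grad_z(grad_theta))
```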

Hope you find this helpful!

Thanks, both. Your advice is indeed very informative and has given me a much greater insight into what affects the behaviour of NUTS under the hood.

I can certainly see a number of @aseyboldt’s points being relevant to my model. My variables are, indeed, correlated and, in some parts of parameter space, unidentifiable (but where this occurs depends on the values of other variables). I’ll attempt to reparameterise and rescale my model parameters (which may require introducing hyperparameters) to get rid of correlated parameters.

I had tried some weakly informative (Normal) priors in an earlier version, but the sampler then often seemed to get stuck in a small part of parameter space far from the truth. However, this might be solved by the aforementioned changes, so I will try weakly informative priors again after reparameterising and rescaling.

Finally, thanks for the advice on timing the leapfrog step. Hopefully I’ll get to the point where I can scale the model!

Is there an updated snippet using pymc3 3.5 to calculate the time spent calculating trajectory?
