Hello,
I’m trying to improve the sampling efficiency of my model and have a few general questions that I hope someone more familiar with PyMC’s internals can help me with. The current estimated runtime for my model is around 24 days, so even small percentage speed-ups would be great to know about.
- Is there a difference between using the NumPyro and BlackJAX backends? Is it worth trying both to see which is faster, or should they be roughly equivalent?
- I analyzed the profiling output from `model.profile(...)`. Do the recommendations at the bottom only affect sampling with the default NUTS sampler, or do they apply to the JAX backends, too?
- Is there an efficiency reason (e.g. better memory management, faster compute) to use `pm.ConstantData()` objects instead of plain NumPy arrays?
Thank you in advance for any help!
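BlackJAX is a bit more experimental than NumPyro.
The recommendations from `profile` probably won't be as relevant for the JAX backends: the profiler measures the graph as compiled by Aesara's default backend, whereas the JAX samplers hand the graph to JAX/XLA to compile.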
You can usually get an automatic ~2x speedup by running everything in float32 instead of float64, if you are okay with the loss in numerical precision. To do this, immediately after importing Aesara, execute `aesara.config.floatX = "float32"`.
Otherwise, it’s difficult to give any tips without knowing what kind of model and data are involved. Sometimes simple parametrization changes can make a large difference; other times you need custom samplers/approximations or have to write low-level code for the bottleneck computations.
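To make the float32 trade-off concrete, here is a NumPy-only sketch (it does not touch Aesara, and the data is made up): halving the bytes per value halves memory and bandwidth, which is plausibly where much of the speedup comes from, while the machine epsilon grows from about 2e-16 to about 1e-7.

```python
import numpy as np

# Hypothetical data: one million draws, stored in both precisions
x64 = np.random.default_rng(0).normal(size=1_000_000)  # float64 by default
x32 = x64.astype(np.float32)

# float32 uses half the memory (and half the memory bandwidth) per value
print(x64.nbytes, x32.nbytes)  # 8000000 4000000

# ...but has a much larger machine epsilon (~1e-7 vs ~2e-16)
print(np.finfo(np.float32).eps, np.finfo(np.float64).eps)

# A small increment that float64 resolves is lost entirely in float32
print(np.float64(1.0) + np.float64(1e-8) == 1.0)  # False
print(np.float32(1.0) + np.float32(1e-8) == 1.0)  # True
```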
How good is your convergence? That usually makes a huge difference in sampling time.
> You can usually get an automatic 2x speedup by running everything in float32 instead of float64, if you are okay with the loss in numerical precision. To do this, immediately after importing aesara execute aesara.config.floatX = "float32"
I tried doing this with the environment variable, and also turned on warnings for float64 variables: `AESARA_FLAGS='floatX=float32, warn_float64=warn'`. I got a lot of warnings about float64 variables being created. Should I spend time figuring out where they are coming from because they’ll slow sampling down, or are they quickly cast down to 32 bit so I need not worry about them?
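If you want to track the float64 variables down rather than just count the warnings, one option (a sketch, based on my understanding that `warn_float64` also accepts `raise` besides `warn`) is to make Aesara raise at the point where a float64 variable is created, so the traceback shows which line of the model produced it. The flags have to be set before Aesara is first imported in the process:

```python
import os

# Aesara reads AESARA_FLAGS when it is first imported, so this must run
# before any `import aesara` or `import pymc`.
# Assumption: `raise` is an accepted value for warn_float64.
os.environ["AESARA_FLAGS"] = "floatX=float32,warn_float64=raise"

# import pymc as pm  # would now pick up floatX=float32
```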
> Otherwise it’s difficult to give any tips without having an idea of the type of model/data involved. Sometimes simple parametrization changes can make a large difference, other times you need custom sampler/approximations or write low level code for the bottleneck computations.
> How good is your convergence? That usually makes a huge difference in sampling time.
Convergence is mediocre. There are a few hierarchical variables with thousands of parameters each, so sampling efficiency is often poor (low ESS and high autocorrelation). I’m also using a covariance structure (non-centered parameterization) on some of these variables and find that it usually slows down sampling.
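For reference, the non-centered transform for a correlated effect looks like this, written in plain NumPy outside any model (the mean and covariance are made-up values). Inside PyMC the Cholesky factor would typically come from something like `pm.LKJCholeskyCov`, and the sampler works on the standard-normal `z` instead of on `theta` directly. Which parameterization samples better depends on how informative the data is per group: non-centering tends to help groups with little data and can hurt groups with a lot, which may be part of why it slows things down here.

```python
import numpy as np

rng = np.random.default_rng(42)

# Hypothetical group-level mean and covariance for a 3-dimensional effect
mu = np.array([0.5, -1.0, 2.0])
Sigma = np.array([[1.0, 0.6, 0.2],
                  [0.6, 2.0, 0.3],
                  [0.2, 0.3, 0.5]])

# Centered: theta ~ MvNormal(mu, Sigma) drawn directly.
# Non-centered: draw standard normals z and set theta = mu + L z,
# where Sigma = L L^T. The prior on z is a fixed standard normal,
# whatever values mu and Sigma take.
L = np.linalg.cholesky(Sigma)
z = rng.standard_normal(size=(200_000, 3))
theta = mu + z @ L.T  # each row is (L z_i)^T shifted by mu

# Same distribution either way, up to Monte Carlo error
print(np.round(theta.mean(axis=0), 2))
print(np.round(np.cov(theta, rowvar=False), 2))
```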
More generally, I feel like I’m very comfortable with the basics of PyMC and Bayesian modeling (particularly GLMs). Do you have any recommendations on how to learn more advanced topics, both specific to PyMC and more generally on statistical modeling? (I realize that this is perhaps beyond the scope of the original question.)
Do you have any discrete variables? Those tend to stubbornly introduce float64, because float32 values mixed with int64 (the default type for discrete variables) are upcast to float64.
Also, some transforms like the simplex transform currently introduce float64 (see “Simplex transform upcasts float32 input to float64”, Issue #91, aesara-devs/aeppl on GitHub).
Otherwise, make sure your observed data is already in float32. In terms of performance, I think JAX just truncates any float64 when running in float32, so it shouldn’t hurt performance much, but the warnings are not pretty either.
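NumPy's type promotion shows the mechanism; my assumption is that Aesara's upcasting follows broadly the same rules, so treat this as an illustration rather than a statement about Aesara internals. The arrays are hypothetical:

```python
import numpy as np

# Hypothetical float32 parameters and an indicator covariate stored as int64
beta = np.ones(3, dtype=np.float32)
x = np.array([0, 1, 1], dtype=np.int64)

# float32 combined with int64 is promoted to float64
print((beta * x).dtype)  # float64

# Small integer types fit exactly in float32 and do not upcast,
# but int32 (like int64) still promotes float32 to float64
print(np.result_type(np.float32, np.int16))  # float32
print(np.result_type(np.float32, np.int32))  # float64

# Casting the data to float32 up front keeps the whole product in float32
x32 = x.astype(np.float32)
print((beta * x32).dtype)  # float32
```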
Thank you for your help, I’ll play around a bit more to see what’s causing the warnings.
I think the main problem with my models, like @twiecki suggested, is convergence. I tend to have more problems when I scale up the model with more groups in the hierarchical variables (i.e. the code of the model doesn’t change, but larger data sets bring in more groups, going from tens or hundreds to thousands or tens of thousands). Is this expected? Naively, I would expect the sampler to take longer per step because there are more computations, but the chain should still explore the space as easily as it did when the model was lower dimensional. Is this thinking correct, or are there special precautions I should take with very large hierarchical variables?
Thank you again for your help and patience. I’m trying to learn more general principles, so I appreciate that these questions are harder to answer than a specific bug in a model.
With larger data the posterior may become very peaked and the gradients unstable, which NUTS can struggle with.
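A conjugate toy case shows how quickly the posterior narrows: for a normal mean with known noise sd σ and a Normal(0, s₀) prior, the posterior sd is (1/s₀² + n/σ²)^(-1/2), which falls roughly like σ/√n. This is only an illustration; in a hierarchical model the trouble tends to come from some directions narrowing like this while others (e.g. group-level scales) stay wide, which can force NUTS into small step sizes and long trajectories.

```python
import numpy as np

def posterior_sd(n, sigma=1.0, prior_sd=10.0):
    """Posterior sd of a normal mean given n observations with known noise
    sd `sigma` and a Normal(0, prior_sd) prior (conjugate closed form)."""
    precision = 1.0 / prior_sd**2 + n / sigma**2
    return 1.0 / np.sqrt(precision)

# 100x more data -> roughly 10x narrower posterior
for n in (10, 1_000, 100_000):
    print(n, posterior_sd(n))
```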
Also from experience (but don’t quote me on that), the higher the data volume the worse a wrong model will perform (assuming we are talking about a complex enough model that does more than say fit group averages).
For very large datasets we sometimes discretize the likelihood but more for speed reasons than convergence issues I think. @ferrine could probably tell you more about that.
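I can't say exactly what is meant here, but one common way to cut the cost of a likelihood over very many rows is to evaluate it once per unique observation and weight it by how often that observation occurs. When observations in the same cell share the same expected value (same group, covariates and exposure) this is exact for count data; for continuous data you would first round into bins, which makes it approximate. A minimal NumPy sketch with hypothetical Poisson data; inside a model, the weighted sum could be added with `pm.Potential`:

```python
import math
import numpy as np

rng = np.random.default_rng(1)

# Hypothetical data: 100,000 Poisson counts spread over 5 groups
group = rng.integers(0, 5, size=100_000)
rate = np.array([0.5, 1.0, 2.0, 3.0, 4.0])
y = rng.poisson(rate[group])

log_factorial = np.vectorize(lambda k: math.lgamma(k + 1), otypes=[float])

def poisson_logpmf(k, mu):
    # log p(k | mu) = k log(mu) - mu - log(k!)
    return k * np.log(mu) - mu - log_factorial(k)

# Full likelihood: one term per observation
full = poisson_logpmf(y, rate[group]).sum()

# Aggregated likelihood: one term per unique (group, count) pair,
# weighted by how many times that pair occurs in the data
pairs, n_obs = np.unique(np.column_stack([group, y]), axis=0, return_counts=True)
agg = (n_obs * poisson_logpmf(pairs[:, 1], rate[pairs[:, 0]])).sum()

print(len(y), len(pairs))  # far fewer unique pairs than observations
print(np.isclose(full, agg))
```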
> With larger data the posterior may become very peaked and the gradients unstable, which NUTS can struggle with.
> Also from experience (but don’t quote me on that), the higher the data volume the worse a wrong model will perform (assuming we are talking about a complex enough model that does more than say fit group averages).
I would agree with that assessment, which makes diagnosing the problems difficult because fitting the larger models takes so much longer.
> For very large datasets we sometimes discretize the likelihood but more for speed reasons than convergence issues I think. @ferrine could probably tell you more about that.
Can you expand on “discretize the likelihood”? I’m not sure what that means in practice.
Also, I sped up the sampling for my model a lot by adjusting the priors based on the prior predictive distributions. My priors were way too wide; I wasn’t accounting for the exponential link function and exposure measurement for my negative binomial regression model. In the end, I think most of my issues were model mis-specification due to poor priors.
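A quick way to see the effect described above without even building the model is a NumPy-only prior predictive check on the intercept (the exposure value and prior widths are hypothetical; in PyMC the same check on the full model is what `pm.sample_prior_predictive` is for). With a log link and an exposure offset, a seemingly vague Normal(0, 10) intercept implies expected counts spanning many orders of magnitude:

```python
import numpy as np

rng = np.random.default_rng(0)
n_draws = 10_000
exposure = 100.0  # hypothetical exposure (e.g. person-years) for one unit

intervals = {}
for prior_sd in (10.0, 1.0):
    # Intercept on the log scale; the log link plus exposure offset gives
    # expected count = exp(intercept + log(exposure))
    intercept = rng.normal(0.0, prior_sd, size=n_draws)
    mu = np.exp(intercept + np.log(exposure))
    lo, hi = np.percentile(mu, [2.5, 97.5])
    intervals[prior_sd] = (lo, hi)
    print(f"prior sd {prior_sd:>4}: 95% of expected counts in [{lo:.3g}, {hi:.3g}]")
```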
(I also got massive speed-ups by using a GPU. Examples of reduced fitting times: from 85 hours down to 2 hours and from over 300 hours to 10 hours!)
He may be referring to this, see tweet here.