PyMC v4 sampling efficiency questions

Hello,
I’m trying to improve the sampling efficiency of my model and have a few general questions that I’m hoping someone more knowledgeable about the foundations of PyMC can help me with. The current estimated runtime for my model is around 24 days, so even relatively small percentage speed-ups would be great to know about.

  1. Is there a difference between using the NumPyro and BlackJAX backends? Is there a reason to try both to see which is fastest, or should they be about equivalent?
  2. I profiled the model using model.profile(...). Do the recommendations at the bottom of the output only have an impact on sampling with the default NUTS sampler, or do they affect the JAX backends, too?
  3. Is there an efficiency reason (e.g. better memory management, faster compute) to use pm.ConstantData() objects instead of plain NumPy arrays? (See the sketch after this list.)
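
For context, a minimal sketch of the pattern behind questions 2 and 3 (the data array and model here are hypothetical stand-ins):

```python
import numpy as np
import pymc as pm

X = np.random.randn(500).astype("float32")  # hypothetical observed data

with pm.Model() as model:
    # pm.ConstantData registers the array as a named constant in the graph
    x = pm.ConstantData("x", X)
    mu = pm.Normal("mu", 0.0, 1.0)
    pm.Normal("obs", mu=mu, sigma=1.0, observed=x)

# profile the joint log-probability graph; recommendations print at the bottom
profiling = model.profile(model.logp())
profiling.summary()
```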

Thank you in advance for any help!


BlackJAX is a bit more experimental than NumPyro.
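
Both are invoked the same way, so trying each is cheap. A sketch, assuming PyMC v4’s pymc.sampling_jax module and an existing model context named model:

```python
import pymc.sampling_jax as sampling_jax

with model:
    # NUTS via the NumPyro backend
    idata_numpyro = sampling_jax.sample_numpyro_nuts(draws=1000, tune=1000)

with model:
    # NUTS via the BlackJAX backend
    idata_blackjax = sampling_jax.sample_blackjax_nuts(draws=1000, tune=1000)
```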

The recommendations from profile probably will not be (as) relevant for the JAX backends.

You can usually get an automatic 2x speedup by running everything in float32 instead of float64, if you are okay with the loss of numerical precision. To do this, execute aesara.config.floatX = "float32" immediately after importing aesara.
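
That is:

```python
import aesara

# set before any model is built so all floatX-typed variables are float32
aesara.config.floatX = "float32"
```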

Otherwise it’s difficult to give any tips without having an idea of the type of model/data involved. Sometimes simple parametrization changes can make a large difference; other times you need custom samplers/approximations or to write low-level code for the bottleneck computations.


How good is your convergence? That usually makes a huge difference in sampling time.

> You can usually get an automatic 2x speedup by running everything in float32 instead of float64, if you are okay with the loss of numerical precision. To do this, execute aesara.config.floatX = "float32" immediately after importing aesara.

I tried doing this via the flag environment variable and also enabled warnings for float64 creation: AESARA_FLAGS='floatX=float32, warn_float64=warn'. I got a lot of warnings about float64 variables being created. Should I spend time figuring out where they come from because they’ll slow the sampling down, or are they quickly cast down to 32 bit so I need not worry about them?
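
If I do want to track them down, one option (assuming aesara kept Theano’s warn_float64 settings, which also accepted 'raise' and 'pdb') would be to escalate the warning so there is a traceback at the creation site:

```python
import aesara

aesara.config.floatX = "float32"
# 'raise' throws an error (with a traceback) wherever a float64 variable
# is created; 'pdb' would drop into the debugger at that point instead
aesara.config.warn_float64 = "raise"
```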

> Otherwise it’s difficult to give any tips without having an idea of the type of model/data involved. Sometimes simple parametrization changes can make a large difference; other times you need custom samplers/approximations or to write low-level code for the bottleneck computations.

> How good is your convergence? That usually makes a huge difference in sampling time.

Convergence is mediocre. There are a few hierarchical variables with thousands of parameters each, so the sampling efficiency is often poor (low ESS and high autocorrelation). I’m also using a covariance structure (non-centered parameterization) on some of these variables and find that it usually slows down sampling.
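
For reference, a hypothetical sketch of the kind of non-centered covariance structure I mean (names and shapes are made up):

```python
import pymc as pm

K, N = 5, 2000  # hypothetical: 5 correlated effects across 2000 groups

with pm.Model() as model:
    chol, corr, stds = pm.LKJCholeskyCov(
        "chol", n=K, eta=2.0,
        sd_dist=pm.Exponential.dist(1.0),
        compute_corr=True,
    )
    # non-centered: draw standard normals, then scale by the Cholesky factor
    z = pm.Normal("z", 0.0, 1.0, shape=(K, N))
    theta = pm.Deterministic("theta", pm.math.dot(chol, z))
```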

More generally, I feel like I’m very comfortable with the basics of PyMC and Bayesian modeling (particularly GLMs). Do you have any recommendations on how to learn more advanced topics, both specific to PyMC and more generally on statistical modeling? (I realize that this is perhaps beyond the scope of the original question.)

Do you have any discrete variables? Those tend to stubbornly introduce float64 when float32 values are mixed with int64 (which is the default dtype for discrete variables).
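
A minimal sketch of that promotion behavior:

```python
import aesara
import aesara.tensor as at

aesara.config.floatX = "float32"

x = at.vector("x")   # float32 under floatX
k = at.lscalar("k")  # int64, the default dtype for discrete values
y = x * k            # mixing int64 with float32 upcasts the result
print(y.dtype)       # float64
```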

Some transforms, like the simplex, also currently introduce float64 (see Simplex transform upcasts float32 input to float64 · Issue #91 · aesara-devs/aeppl · GitHub).

Otherwise, make sure your observed data is already in float32. In terms of performance, I think JAX just truncates any float64 to float32 when running in float32 mode, so it shouldn’t hurt performance much, but the warnings are not pretty either 🙂
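
For example, casting up front (raw_y is a hypothetical array):

```python
import numpy as np
import pymc as pm

raw_y = [0.3, 1.2, -0.7]                # hypothetical raw observations
y = np.asarray(raw_y, dtype="float32")  # cast before building the model

with pm.Model() as model:
    mu = pm.Normal("mu", 0.0, 1.0)
    pm.Normal("obs", mu=mu, sigma=1.0, observed=y)
```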