PSA: pm.sample now has full integration with the Numba backend

With the latest release, PyMC 5.19.0 (Release v5.19.0 · pymc-devs/pymc · GitHub), it is now possible to pass compile_kwargs=dict(mode="NUMBA") to pm.sample to use the Numba backend with the PyMC samplers. For some models this can lead to big speedups in sampling, although these may be partly offset by longer compilation times.

For the eight schools model benchmarked in Speedup `sample` and allow specifying `compile_kwargs` (several major changes related to step samplers) by ricardoV94 · Pull Request #7578 · pymc-devs/pymc · GitHub, sampling was roughly 2x faster with the Numba backend than with the default C backend.

Sampling with the default backend should itself also be roughly 2x faster than in previous versions, thanks to some optimizations we made. In case you need an excuse to update…

And if you are already willing to play with the Numba backend, you can also install nutpie and try nuts_sampler="nutpie". Thanks to smarter initialization and its exclusive focus on NUTS sampling, it can be much faster than the PyMC NUTS sampler: GitHub - pymc-devs/nutpie: Python wrapper for nuts-rs

Let us know if you have any questions or find any problems!

OK, potentially silly question: if we're using nuts_sampler="nutpie", should we also pass compile_kwargs=dict(mode="NUMBA") along with it?

Not a silly question. You don't need to, and it will have no impact.

Using nutpie cut our sampling time almost in half. Kudos to the Rust developers who reimplemented NUTS.

Credit for nutpie goes to @aseyboldt (plus community work on the PyTensor Numba backend).

Also, nutpie can often get away with shorter tuning, so you may be able to save even more runtime.

Previously, estimating a hierarchical model with 200k observations using numpyro_nuts took around one day. By switching to nutpie and Numba, I was able to reduce the runtime to under 11 minutes. Both runs used the same recent Apple processor with a properly configured Accelerate framework.

Perhaps the model had geometry that was suboptimal for numpyro_nuts, or I misspecified it, but this improvement is truly impressive. Thank you @ricardoV94 and the rest of the PyMC team for all the hard work that made this possible!
