PyMC runtime question (Jax, BLAS, aesara)


First of all, thanks a lot for this marvellous library 🙂!

I started using PyMC (v4) a month ago and have run several experiments to speed up my models as much as possible.
I have several questions about how PyMC works internally, which would help me understand my results, and I cannot find a clear answer in the PyMC documentation. Here they are:

  1. When I run “pm.sample” (with the default arguments), does PyMC use JAX, or does Aesara compile the model to C?

  2. I have noticed a significant speedup of my model with PyMC (pm.sample) compared with PyStan, and several CPU cores are used for a single chain. Is this behavior related to the BLAS libraries? If so, how does it work in practice, and is it controlled by JAX or by Aesara? I am not sure I understand what is happening under the hood. What does BLAS actually do?

  3. The documentation says that the model can be compiled to JAX. If I use “pm.sample” with JAX installed, does that mean the model will run in JAX and the sampler in C?

  4. When I run “jx.sample_numpyro_nuts”, does the sampling itself also use JAX? I am not 100% sure I understand the benefit of using it.

  5. When I run “jx.sample_numpyro_nuts” on a GPU, am I right to assume that the sampling happens on the CPU while the computations (matrix operations, gradient computation, etc.) run on the GPU?

  6. When I run the model on a GPU, I also notice a significant speedup when I use float32 instead of float64. Is that because less data has to be transferred from the CPU to the GPU?

I know this is a lot of questions for one message, but I could really use some clarification, and I hope these questions help others too 🙂!

Many thanks in advance !


These are all great questions!

With the default arguments, the model logp and dlogp are compiled with Aesara's default backend, which is C. You can change the default by setting aesara.config.mode = "JAX" (or "NUMBA"). The sampling itself is done by PyMC's pure-Python NUTS implementation.

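When you compile to C with Aesara, some operations (matrix products in particular) automatically call BLAS for faster computation. BLAS (Basic Linear Algebra Subprograms) is a standard interface for low-level vector and matrix routines, and optimized implementations such as OpenBLAS or Intel MKL can run a single operation, most commonly GEMM (general matrix-matrix multiplication), on several threads. That is most likely why you see several cores busy for a single chain: the threads come from the BLAS library, not from JAX. You can read more about it in the Aesara docs: “Multi cores support in Aesara”.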
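A rough way to see BLAS at work outside PyMC is through NumPy, which also dispatches float64 matrix products to its BLAS library. This sketch assumes a NumPy build linked against a multithreaded BLAS such as OpenBLAS or MKL; the thread count of 4 is an arbitrary example, and the environment variable only takes effect if it is set before NumPy is first imported in the process.

```python
import os

# Example thread count; must be set before NumPy loads its BLAS library.
# OpenBLAS and MKL typically honour OMP_NUM_THREADS unless their own
# OPENBLAS_NUM_THREADS / MKL_NUM_THREADS variables are set.
os.environ.setdefault("OMP_NUM_THREADS", "4")

import numpy as np

# Report which BLAS/LAPACK library this NumPy build is linked against
np.show_config()

rng = np.random.default_rng(0)
a = rng.standard_normal((2000, 500))
b = rng.standard_normal((500, 300))

# A float64 matrix-matrix product: NumPy hands this to the BLAS GEMM routine (dgemm)
c = a @ b
print(c.shape)
```

Running a larger product in a loop while watching top or htop should typically show several cores busy for this one Python process; the same kind of call happens inside Aesara's compiled C code when it uses GEMM.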

No, installing JAX does not change what pm.sample does. As mentioned in the first answer, by default your model is compiled to C and run by a Python sampler. You can manually change the default mode to JAX, but pm.sample will still use the Python NUTS sampler; only the backend of the model logp and dlogp changes.

Yes. The NumPyro NUTS sampler is written in JAX rather than in plain Python, so its loops are compiled as well. This can give a large speedup, because NUTS is essentially a set of nested loops (building a trajectory tree out of leapfrog steps), and Python is notoriously slow at running loops. XLA may also bring other advantages for NUTS algorithms, but I don't know of any. There is also a sampler written in Rust by one of the PyMC core developers that shows the same sort of speedup, probably for the same reasons: nutpie, a Python wrapper for nuts-rs (GitHub: aseyboldt/nutpie).

I think everything happens on the GPU (or on the CPU, if that is the default device you set for JAX), but I could be wrong. It might also be possible to choose this at runtime through some keyword arguments to sample_numpyro_nuts; you would have to check.
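One way to check where JAX will run, as a sketch assuming only that jax is installed: print the default backend and the visible devices, then jit-compile a toy log-density and its gradient. Anything jitted without explicit device placement runs on that default device; presumably the same holds for NumPyro's sampler, but that is an assumption, not something this snippet verifies. I believe setting the environment variable JAX_PLATFORMS=cpu before importing JAX forces the CPU, but check the JAX documentation for your version.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Which backend JAX picked by default (e.g. "cpu", "gpu" or "tpu") and the devices it sees
print("default backend:", jax.default_backend())
print("devices:", jax.devices())

# Toy log-density and gradient, jit-compiled by XLA for the default device.
# Assumption: a JAX-based sampler such as NumPyro's NUTS is placed the same way.
@jax.jit
def logp(x):
    return -0.5 * jnp.sum(x**2)

grad_logp = jax.jit(jax.grad(logp))

x = jnp.arange(3.0)
print(logp(x), grad_logp(x))
```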

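Most GPUs, consumer models in particular, have much higher float32 than float64 throughput, so they are much faster when working with float32. This is well-known behavior, and the faster arithmetic probably matters more than the smaller CPU-to-GPU transfers. Some TPUs go even further and get speedups from operating in 16-bit floating point. Deep learning practitioners say they don't need the extra precision; it is unclear whether that is the case for Bayesian inference.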
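To illustrate the precision side of the trade-off, here is a small NumPy sketch (no GPU needed): float32 halves the storage per number but keeps only about 7 significant digits. The last two lines show the resolution at a total log-density magnitude of about 1e6, an arbitrary but plausible value for a large dataset; in float32, representable values there are 0.0625 apart.

```python
import numpy as np

# Bytes per number and machine epsilon (gap between 1.0 and the next representable value)
for dtype in (np.float32, np.float64):
    print(np.dtype(dtype).name, np.dtype(dtype).itemsize, "bytes, eps =", np.finfo(dtype).eps)

# An increment of 1e-8 is below float32 resolution at 1.0 but not below float64's
one32, one64 = np.float32(1.0), np.float64(1.0)
print("float32 sees 1e-8:", one32 + np.float32(1e-8) != one32)
print("float64 sees 1e-8:", one64 + np.float64(1e-8) != one64)

# Spacing between adjacent representable values near a log-density magnitude of 1e6
print("float32 spacing at 1e6:", np.spacing(np.float32(1e6)))
print("float64 spacing at 1e6:", np.spacing(np.float64(1e6)))
```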