For a few days, I have been running some experiments concerning the sampling time of a rather big/complex model and I have some questions regarding what I found:

When using JAX for sampling (with pm.sampling_jax.sample_numpyro_nuts), sampling is slower on GPU than on CPU (for one chain). For 1000 warmup steps and 1000 draws it takes around 14 min on CPU and 19 min on GPU. I have checked the versions of all my libraries and they look okay, and the model is indeed using the GPU when asked to.
- Is there a way to boost performance on GPU?
- Is there a limit case (in terms of data set size or model complexity) under which the GPU is less performant than the CPU?

When using float32 precision instead of float64 I notice a rather considerable speedup of my models (going from 14 to 4 min on CPU and from 19 to 9 min on GPU). What explains this speedup?

I find it pretty difficult to understand what is happening under the hood with PyMC/PyTensor/JAX. Where can I find a good explanation of it? For instance, I am wondering in which cases the model uses BLAS libraries (does PyTensor with JAX use BLAS, for instance?), and what is run on the CPU and what is run on the GPU.

I know that these are specific questions, but I have been struggling for a while to find good answers.

Slower sampling on a GPU is not at all unusual. GPUs are (unfortunately) not universally faster than CPUs. Their theoretical maximum throughput (floating point ops per second) is higher, but the actual speed depends a lot on what you are trying to compute, and how smart the GPU implementation is. GPUs are terrible at pretty much everything that doesn't involve running the same independent computations on lots of different numbers. This tends to be exactly what people do in deep learning (lots of matrix-vector products), which is why GPUs are used so much in that field, but quite often that is not the case in PyMC models. If the computation time is mostly due to large dense linear algebra operations (for instance in some GP models), then a GPU might perform well, however. But since a model can contain pretty much arbitrary computations, there really is no simple rule for when GPUs work well.
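
To make the "same independent computation on lots of numbers" point a bit more concrete, here is a small NumPy sketch (not GPU code, and not taken from any real model). The array size, the `recurrence` helper and the `rho` parameter are made up for illustration. The first computation is the kind of thing that maps well onto a GPU; the second, written this way, has to go step by step.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=10_000)

# Data-parallel: every output element depends on a single input element,
# so the work can be split across many cores in any order.
elementwise = np.exp(-0.5 * x ** 2)


def recurrence(values, rho=0.9):
    """Hypothetical AR(1)-style recurrence: each step needs the previous one."""
    out = np.empty_like(values)
    acc = 0.0
    for t, value in enumerate(values):
        acc = rho * acc + value  # step t cannot start before step t - 1 is done
        out[t] = acc
    return out


sequential = recurrence(x)
```

Models whose log density is dominated by chains of dependent steps like the second one give a GPU little to parallelize, which is one plausible way to end up slower than on the CPU.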

Several reasons come to mind. float32 values need only half the number of bytes, so they can be transferred between main memory and the CPU faster. Often the smaller size can also mean that intermediate results might fit into the cache of the CPU with float32 but not with float64 values, which means slow memory access might not be needed in some cases. And SIMD operations in the CPU can work with twice as many values at the same time when the precision is smaller. Many GPUs have only limited support for float64, which might slow things down as well. Got to admit 14 min vs 4 is more than I would usually expect...
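
The memory part of this is easy to check with NumPy. This is only a sketch of the size argument; the array length is arbitrary, and it says nothing about how much faster a given model will actually run.

```python
import numpy as np

n = 1_000_000  # arbitrary number of values, for illustration only

a64 = np.zeros(n, dtype=np.float64)
a32 = np.zeros(n, dtype=np.float32)

# Bytes per value and total bytes for each precision.
print(a64.itemsize, a64.nbytes)  # 8 bytes per value, 8_000_000 bytes in total
print(a32.itemsize, a32.nbytes)  # 4 bytes per value, 4_000_000 bytes in total

# The same data takes half the memory in float32, so twice as many values
# fit into a cache of a given size or into one SIMD register.
ratio = a64.nbytes / a32.nbytes
```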

PyMC uses PyTensor to represent the model as a computation graph. That graph can then be used to generate a PyTensor computation graph for the posterior log probability density and its derivative. We can then either compile that using the PyTensor C backend or the numba backend, or we can translate it to jax. When using the jax backend, PyTensor will just translate matrix-vector multiplications to the corresponding jax functions, and jax will itself decide how that should be executed (often I think it will use BLAS).
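
As a rough picture of what the sampler ends up with on the jax side, here is a hand-written sketch. It is not what PyTensor actually generates: the toy regression model, the `logp` function and the data are all made up for illustration. It only shows the general shape, which is a jax function for the log density, its gradient from `jax.value_and_grad`, and compilation with `jax.jit`.

```python
import jax
import jax.numpy as jnp


# Hypothetical toy model: y ~ Normal(X @ beta, 1), beta ~ Normal(0, 1).
def logp(beta, X, y):
    resid = y - X @ beta                   # matrix-vector product, executed however jax decides
    log_lik = -0.5 * jnp.sum(resid ** 2)   # Gaussian likelihood, constants dropped
    log_prior = -0.5 * jnp.sum(beta ** 2)  # standard normal prior, constants dropped
    return log_lik + log_prior


# NUTS needs the value and the gradient; jit compiles both together.
logp_and_grad = jax.jit(jax.value_and_grad(logp))

X = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
y = jnp.array([1.0, 2.0, 3.0])
beta = jnp.zeros(2)

value, grad = logp_and_grad(beta, X, y)
```

Which device this runs on, and which low-level library handles `X @ beta`, is decided by jax and not by this code.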