Sampling time GPU vs CPU

Hi :slight_smile:

  1. Slower sampling on a GPU is not at all unusual. GPUs are (unfortunately) not universally faster than CPUs. Their theoretical maximum throughput (floating-point operations per second) is higher, but the actual speed depends a lot on what you are trying to compute and how well the GPU implementation is written. GPUs are terrible at pretty much everything that doesn’t involve running the same independent computation on lots of different numbers. That tends to be exactly what people do in deep learning (lots of matrix-vector products), which is why GPUs are used so much in that field, but quite often it is not the case in PyMC models. If the computation time is dominated by large dense linear algebra operations (for instance in some GP models), however, a GPU might perform well. But since a model can contain pretty much arbitrary computations, there really is no simple rule for when GPUs work well.

  2. Several reasons come to mind. float32 values need only half as many bytes as float64 values, so they can be transferred between main memory and the CPU faster. The smaller size also often means that intermediate results fit into the CPU cache with float32 but not with float64 values, so slow main-memory accesses can sometimes be avoided. And SIMD instructions on the CPU can process twice as many values at once at the lower precision. Many GPUs also have only limited float64 support, which can slow things down as well. I’ve got to admit that 14 min vs 4 min is a bigger gap than I would usually expect…

  3. PyMC uses PyTensor to represent the model as a computation graph. That graph is then used to generate a PyTensor computation graph for the posterior log probability density and its derivative. We can then compile that graph using the PyTensor C backend or the numba backend, or we can translate it to JAX. With the JAX backend, PyTensor simply translates matrix-vector multiplications to the corresponding JAX functions, and JAX itself decides how they should be executed (often, I think, by calling BLAS).
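Regarding point 1, here is a rough plain-Python illustration (no GPU involved) of the two kinds of computation. The function names and the AR(1) recurrence are my own choices for illustration, not anything from PyMC. The first function repeats the same independent calculation for every data point, which is the kind of work a GPU can spread over many threads. The second has each step depend on the previous one, so unless the algorithm is restructured (e.g. as a parallel scan) there is little for a GPU to parallelise.

```python
import math


def normal_logp_terms(xs, mu, sigma):
    """Per-observation Normal log-density terms.

    Each term depends only on its own x: the same independent computation
    repeated over many numbers, which maps well onto GPU threads.
    """
    const = -0.5 * math.log(2 * math.pi) - math.log(sigma)
    return [const - 0.5 * ((x - mu) / sigma) ** 2 for x in xs]


def ar1_path(innovations, phi, x0=0.0):
    """Latent AR(1) path x_t = phi * x_{t-1} + eps_t (illustrative example).

    Step t needs the result of step t-1, so the loop is inherently
    sequential as written.
    """
    path = []
    x = x0
    for eps in innovations:
        x = phi * x + eps
        path.append(x)
    return path


print(normal_logp_terms([0.0, 1.0, -1.0], mu=0.0, sigma=1.0))
print(ar1_path([1.0, 1.0, 1.0], phi=0.5))
```

Real models often mix both patterns, which is one reason there is no simple rule for when a GPU helps.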
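The factors of two in point 2 are simple arithmetic. A quick standard-library sketch: the 32 KiB L1 data cache and the 256-bit SIMD register width are assumed values (common on current x86 desktop CPUs), not properties of any particular machine, and `values_that_fit` is just a helper made up for this example.

```python
from array import array

# Bytes per value: 'f' is a C float (float32), 'd' a C double (float64)
# on all mainstream platforms.
BYTES_F32 = array("f").itemsize  # 4
BYTES_F64 = array("d").itemsize  # 8

# Assumed hardware sizes, for illustration only.
L1_CACHE_BYTES = 32 * 1024   # hypothetical 32 KiB L1 data cache
SIMD_REGISTER_BITS = 256     # hypothetical AVX2-style vector register


def values_that_fit(n_bytes, bytes_per_value):
    """How many values of a given size fit into n_bytes."""
    return n_bytes // bytes_per_value


n_values = 1_000_000
memory_f32 = n_values * BYTES_F32  # bytes to move for a float32 vector
memory_f64 = n_values * BYTES_F64  # bytes to move for a float64 vector

cache_f32 = values_that_fit(L1_CACHE_BYTES, BYTES_F32)
cache_f64 = values_that_fit(L1_CACHE_BYTES, BYTES_F64)

lanes_f32 = values_that_fit(SIMD_REGISTER_BITS // 8, BYTES_F32)
lanes_f64 = values_that_fit(SIMD_REGISTER_BITS // 8, BYTES_F64)

print(f"memory for {n_values} values: {memory_f32} vs {memory_f64} bytes")
print(f"values per L1 cache:     {cache_f32} vs {cache_f64}")
print(f"values per SIMD register: {lanes_f32} vs {lanes_f64}")
```

Each of these effects is a factor of two on its own, which is part of why a larger gap is surprising.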

I hope that helps a bit at least?
