PyMC3 / PyMC v4 GPU example for a basic hierarchical model with NUTS sampling

Are there any clear-cut instructions anyone can point me to for running a hierarchical PyMC3 model with NUTS on the GPU?

I know the sampler itself runs on the CPU, but I've seen that the matrix operations can still be offloaded to the GPU, and in my use case that might give a significant speed-up.

Alternatively, I might try the PyMC v4 + JAX backend option, but I couldn't find a clear example of running a basic model on the GPU.

Any pointers to the right material would be appreciated.


Did you find this post? That seems to be semi-up-to-date (e.g., as mentioned here).
