Can anyone point me to clear-cut instructions for running a hierarchical pymc3 model with NUTS on the GPU?
I know that the sampler itself runs on the CPU, but I’ve seen that matrix operations can still be offloaded to the GPU, and in my use case that might offer a significant speed-up.
Alternatively, I might try the pymc v4 + JAX backend option, but I couldn’t find a clear example of running a basic model on the GPU.
Any pointers to the right material would be appreciated.