I am working with a version of the Bayesian neural network model in this example notebook. How can I change it to sample on the GPU using a JAX backend?
You would have to use `pm.sample(..., nuts_sampler="numpyro")` while running on a machine with an available GPU and a GPU-enabled JAX installation.
Just so I understand correctly: I will have to use an MCMC-based sampling approach rather than the ADVI approach used in the tutorial?
There is some work going on on this front here: Make VI compatible with JAX backend (pymc-devs/pymc Pull Request #7103 by ferrine)