JAX sampling for a Bayesian neural network

I am working with a version of the Bayesian neural network model in this example notebook. How can I change it to sample on the GPU using a JAX backend?

You would have to use pm.sample(..., nuts_sampler="numpyro") while running on a machine with an available GPU and a GPU-enabled JAX installation.

Just so I understand correctly: I will have to use an MCMC-based sampling approach rather than the ADVI inference approach used in the tutorial?

There is some ongoing work on this front: Make VI compatible with JAX backend by ferrine · Pull Request #7103 · pymc-devs/pymc · GitHub