Hello,
I’ve been reading this post ("MCMC for big datasets: faster sampling with JAX and the GPU" by PyMC Labs) on how the JAX backend (via NumPyro), together with a GPU, can speed up the NUTS sampler. I was wondering whether the same speedup also applies to the variational inference algorithms in PyMC. Or is pymc.sampling.jax.sample_numpyro_nuts currently the only JAX option, and only for MCMC?
Thanks a lot,
Frank