Does the JAX backend speedup also apply to variational inference?


I’ve been reading this post (MCMC for big datasets: faster sampling with JAX and the GPU - PyMC Labs) on how the JAX backend (using NumPyro), together with a GPU, can speed up the NUTS sampler. I was wondering whether that also applies to the variational inference algorithms in PyMC. Or do we currently only have the pymc.sampling.jax.sample_numpyro_nuts function, which is for MCMC?

Thanks a lot,

Currently, I think PyMC's VI might still be incompatible with the JAX backend. You can check by setting pytensor.config.mode = "JAX" (if you are on pymc>=5.0.0) before trying to fit a model.

If you install pymc-experimental, you can use BlackJAX's Pathfinder: Pathfinder Variational Inference — PyMC example gallery

Really appreciate the information!