Does the JAX backend speedup also apply to variational inference?

Currently, I think PyMC's VI might still be incompatible with the JAX backend. You can check by setting pytensor.config.mode = "JAX" (if you are on pymc>=5.0.0) before trying to fit a model.

If you install pymc-experimental, you can use BlackJAX's Pathfinder: see "Pathfinder Variational Inference" in the PyMC example gallery.