Currently I think PyMC's VI might still be incompatible with the JAX backend. You can check by setting pytensor.config.mode = "JAX" (if you are on pymc>=5.0.0) before trying to fit a model.
If you install pymc-experimental, you can use blackjax's Pathfinder implementation instead: Pathfinder Variational Inference — PyMC example gallery