Is PyMC capable of using forward-mode autodiff (JVP) from JAX?

Nothing else should be needed for forward-mode differentiation in NumPyro, since that is done by JAX (not PyTensor) on the logp graph returned by PyMC. You just have to try passing that flag (`forward_mode_differentiation=True`) to the NumPyro sampler.
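As a rough illustration of why no PyTensor support is required: once PyMC has compiled the model to a plain JAX log-density function, JAX can differentiate it in either mode. The sketch below uses a hypothetical `potential_fn` (negative log density of independent standard normals) as a stand-in for the function PyMC hands to NumPyro. It compares reverse mode (`jax.grad`) with forward mode (`jax.jacfwd`, built from JVPs) and shows a single `jax.jvp` call. This is not PyMC's or NumPyro's actual code, only the JAX mechanism underneath.

```python
import jax
import jax.numpy as jnp


# Hypothetical stand-in for the logp function PyMC compiles to JAX.
# NumPyro samples using the *potential*, i.e. the negative log density.
def potential_fn(theta):
    logp = -0.5 * jnp.sum(theta**2) - 0.5 * theta.size * jnp.log(2.0 * jnp.pi)
    return -logp


theta = jnp.array([0.5, -1.0, 2.0])

# Reverse mode: one backward pass gives the full gradient.
grad_rev = jax.grad(potential_fn)(theta)

# Forward mode: jacfwd assembles the gradient from one JVP per input dimension.
grad_fwd = jax.jacfwd(potential_fn)(theta)

# A single JVP: the directional derivative along v, without forming the full gradient.
v = jnp.array([1.0, 0.0, 0.0])
value, directional = jax.jvp(potential_fn, (theta,), (v,))

print(grad_rev)  # for this potential the gradient equals theta
print(grad_fwd)  # same values, computed in forward mode
print(directional)  # dot(theta, v), i.e. the first component of theta
```

On the PyMC side, the flag would be forwarded to NumPyro's `NUTS` kernel. In recent versions, for example, it goes through the `nuts_kwargs` dict of `pymc.sampling.jax.sample_numpyro_nuts`. Check the signature for your PyMC version, since these sampling helpers have been reorganized over time.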