Is PyMC capable of using forward-mode autodiff (JVP) from JAX?

Would it be feasible to add a convenience option for this to the NUTS sampler? I noticed that NumPyro's NUTS kernel has an argument called "forward_mode_differentiation" (see the "Markov Chain Monte Carlo (MCMC)" page of the NumPyro documentation).

For now I'll probably learn NumPyro for sampling JAX-based code, but I'm curious whether PyMC could support this.