Thanks @ricardoV94. What if, for the diffrax example in the blog post, I wanted to use forward mode (jvp) instead of reverse mode (vjp)? Can PyMC use forward-mode gradients?
(NumPyro has an option to use forward mode; see the "Markov Chain Monte Carlo (MCMC)" page of the NumPyro documentation.)
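
To make the question concrete, here is a minimal JAX sketch of the two modes I mean. The `logp` function is a hypothetical toy log-density that I made up as a stand-in for the diffrax-based model from the blog post; it is not PyMC or diffrax code, just an illustration of jvp versus vjp on a scalar-valued function.

```python
import jax
import jax.numpy as jnp


def logp(theta):
    # Hypothetical toy log-density (assumption, stands in for the ODE model).
    return -0.5 * jnp.sum((theta - 1.0) ** 2)


theta = jnp.array([0.5, 2.0, -1.0])

# Reverse mode (vjp): a single backward pass yields the whole gradient
# of a scalar output.
value, vjp_fn = jax.vjp(logp, theta)
(grad_rev,) = vjp_fn(jnp.ones_like(value))


# Forward mode (jvp): each pass yields one directional derivative, so the
# full gradient needs one pass per input dimension (batched here with vmap).
def directional_derivative(direction):
    _, tangent_out = jax.jvp(logp, (theta,), (direction,))
    return tangent_out


grad_fwd = jax.vmap(directional_derivative)(jnp.eye(theta.size))

# jax.jacfwd packages the same forward-mode idea.
grad_jacfwd = jax.jacfwd(logp)(theta)
```

Both routes should give the same gradient. As I understand it, forward mode costs one pass per parameter, so it would mostly make sense for models with few parameters or where reverse mode through the solver is a problem. In NumPyro the switch is, as far as I can tell, the `forward_mode_differentiation=True` argument of `NUTS`.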