I’m going through this excellent tutorial on using JAX gradients in PyMC: How to use JAX ODEs and Neural Networks in PyMC - PyMC Labs
It says that PyMC uses reverse-mode autodiff (VJP) to get the gradients – the rows of the Jacobian – so that you can map perturbations in the outputs back to perturbations in the inputs.
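To make the distinction concrete, here is a minimal JAX sketch of the two modes. The toy function `f` and the input `x` are made up for illustration and are not from the tutorial: `jax.vjp` pulls an output perturbation back to the inputs (a row of the Jacobian when the cotangent is a basis vector), while `jax.jvp` pushes an input perturbation forward to the outputs (a column).

```python
import jax
import jax.numpy as jnp


def f(x):
    # Hypothetical toy map from R^3 to R^2, used only for illustration.
    return jnp.array([x[0] * x[1], jnp.sin(x[2])])


x = jnp.array([1.0, 2.0, 3.0])

# Reverse mode: a basis cotangent on the outputs selects one Jacobian row.
y, vjp_fn = jax.vjp(f, x)
(row0,) = vjp_fn(jnp.array([1.0, 0.0]))

# Forward mode: a basis tangent on the inputs selects one Jacobian column.
_, col0 = jax.jvp(f, (x,), (jnp.array([1.0, 0.0, 0.0]),))

# Full Jacobian (built column by column in forward mode) for comparison.
J = jax.jacfwd(f)(x)
```

Here `row0` should match `J[0]` and `col0` should match `J[:, 0]`.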
Just curious: does PyMC support forward-mode autodiff (JVP)? For a particular problem I’m working on in JAX, JVP works but VJP doesn’t yet (it returns all NaNs), and until I track down why, I’d like to know whether I can somehow use JVP for PyMC inference.
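One possible workaround on the JAX side, sketched here under assumptions rather than as anything PyMC documents: wrap the function with `jax.custom_vjp` so that the reverse-mode rule is assembled from a forward-mode Jacobian (`jax.jacfwd`). The names `forward_only` and `f` are hypothetical placeholders for the real model function, and whether the wrapped function slots cleanly into the tutorial's PyMC wrapper is an assumption I have not verified.

```python
import jax
import jax.numpy as jnp


def forward_only(x):
    # Hypothetical stand-in for a function whose JVP works but whose
    # built-in VJP is problematic.
    return jnp.array([x[0] * x[1], jnp.sin(x[2])])


@jax.custom_vjp
def f(x):
    return forward_only(x)


def f_fwd(x):
    # Compute the full Jacobian in forward mode and save it as the residual.
    jac = jax.jacfwd(forward_only)(x)
    return forward_only(x), jac


def f_bwd(jac, g):
    # VJP = cotangent (output-shaped) times Jacobian; one entry per input.
    return (g @ jac,)


f.defvjp(f_fwd, f_bwd)

x = jnp.array([1.0, 2.0, 3.0])
# Reverse-mode calls on f now go through the forward-mode Jacobian.
grad_sum = jax.grad(lambda v: f(v).sum())(x)
```

This costs one forward-mode pass per input dimension, so it is probably only reasonable when the number of parameters is small.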