Is PyMC capable of using forward-mode autodiff (JVP) from JAX?

I’m going through this excellent tutorial on how to use gradients from JAX in PyMC: How to use JAX ODEs and Neural Networks in PyMC - PyMC Labs

It says that PyMC uses reverse-mode autodiff (VJP) to get the gradients – each VJP yields a row of the Jacobian – so that perturbations in the outputs can be mapped back to perturbations in the inputs.

Just curious – does PyMC have support for using forward-mode autodiff (JVP)? For a particular problem I’m working on in JAX, the JVP works but the VJP doesn’t yet (the latter gives all NaNs), and until I track down why, I’m curious whether I can somehow use the JVP for PyMC inference.
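
For reference, this is the kind of check I mean, with a toy `f` standing in for my actual model (the NaNs only show up in the real one):

```python
import jax
import jax.numpy as jnp

def f(x):
    # toy stand-in for my actual model
    return jnp.sin(x) * x

x = jnp.linspace(0.0, 1.0, 5)
v = jnp.ones_like(x)

# forward mode: one call to jax.jvp gives the tangent J @ v
y, tangent = jax.jvp(f, (x,), (v,))

# reverse mode: one call to the pullback gives the cotangent u^T J
y, pullback = jax.vjp(f, x)
(cotangent,) = pullback(jnp.ones_like(y))

print(tangent)    # fine for my problem
print(cotangent)  # all NaNs for my actual model
```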

Yes, we have forward-mode autodiff; it’s called R_op instead of L_op. You might have to investigate how it maps to the JAX API.

General overview: Derivatives in PyTensor — PyTensor dev documentation

There’s some implementation info here: Creating a new Op: Python implementation — PyTensor dev documentation
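
To make the mapping concrete, here’s a minimal, untested sketch of what that could look like: a PyTensor Op wrapping a JAX function `f`, with its forward-mode derivative supplied through R_op and the JVP itself delegated to jax.jvp. The Op names and `f` are placeholders, not anything from the tutorial:

```python
import jax
import jax.numpy as jnp
import numpy as np
import pytensor.tensor as pt
from pytensor.graph.basic import Apply
from pytensor.graph.op import Op


def f(x):
    # placeholder for the JAX function whose VJP is misbehaving
    return jnp.sin(x) * x


class FJvpOp(Op):
    """Evaluates the JVP of f at x in direction v via jax.jvp."""

    def make_node(self, x, v):
        x = pt.as_tensor_variable(x)
        v = pt.as_tensor_variable(v)
        return Apply(self, [x, v], [x.type()])

    def perform(self, node, inputs, output_storage):
        x, v = inputs
        _, tangent = jax.jvp(f, (jnp.asarray(x),), (jnp.asarray(v),))
        output_storage[0][0] = np.asarray(tangent)


class FOp(Op):
    """Wraps f and supplies forward-mode derivatives through R_op."""

    def make_node(self, x):
        x = pt.as_tensor_variable(x)
        return Apply(self, [x], [x.type()])

    def perform(self, node, inputs, output_storage):
        (x,) = inputs
        output_storage[0][0] = np.asarray(f(jnp.asarray(x)))

    def R_op(self, inputs, eval_points):
        # R_op receives the inputs plus the perturbation directions
        # (eval_points) and returns the output perturbations, i.e. J @ v,
        # which we delegate to jax.jvp via the helper Op above.
        (x,) = inputs
        (v,) = eval_points
        if v is None:
            return [None]
        return [FJvpOp()(x, v)]
```

You could then build the J @ v graph with pytensor.gradient.Rop, e.g. `Rop(FOp()(x), x, v)`; the derivatives page linked above covers the call pattern (exact dispatch may depend on your PyTensor version).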