Is PyMC capable of using forward-mode autodiff (JVP) from JAX?

I’m going through this excellent tutorial on how to use gradients from JAX in PyMC: How to use JAX ODEs and Neural Networks in PyMC - PyMC Labs

It says that PyMC uses reverse-mode autodiff (VJP) to get the gradients – the rows of the Jacobian – so that you can map perturbations in the outputs back to perturbations in the inputs.
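For concreteness, here’s how I understand the two modes in JAX, using a toy function purely for illustration:

```python
import jax
import jax.numpy as jnp

def f(x):
    # Toy function from R^2 to R^2, just for illustration
    return jnp.array([x[0] * x[1], jnp.sin(x[0])])

x = jnp.array([1.0, 2.0])

# Reverse mode (VJP): pull an output perturbation v back to the inputs,
# i.e. compute v @ J; with a unit vector v you recover one row of J.
_, vjp_fn = jax.vjp(f, x)
row0 = vjp_fn(jnp.array([1.0, 0.0]))[0]

# Forward mode (JVP): push an input perturbation u forward through f,
# i.e. compute J @ u; with a unit vector u you recover one column of J.
_, col0 = jax.jvp(f, (x,), (jnp.array([1.0, 0.0]),))
```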

Just curious – does PyMC have support for using forward-mode autodiff (JVP)? For a particular problem I’m working on in JAX, JVP works but VJP doesn’t yet (the latter gives all NaNs), and until I track down why, I’d like to know if I can somehow use JVP for PyMC inference.

Yes, we have forward-mode autodiff; it’s called R_op instead of L_op. You might have to investigate how it maps to the JAX API.

General overview: Derivatives in PyTensor — PyTensor dev documentation

There’s some implementation info here: Creating a new Op: Python implementation — PyTensor dev documentation
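For example, here’s a rough, untested sketch of how an Op wrapping a JAX function could expose forward-mode derivatives through R_op (the function f and the Op names are made up for illustration):

```python
import jax
import jax.numpy as jnp
import numpy as np
import pytensor.tensor as pt
from pytensor.graph.basic import Apply
from pytensor.graph.op import Op


def f(x):
    # Hypothetical JAX function you want to use inside a PyMC model
    return jnp.sin(x) * x


class FJvpOp(Op):
    """Given x and a tangent dx, computes the JVP J(x) @ dx via jax.jvp."""

    def make_node(self, x, dx):
        x = pt.as_tensor_variable(x)
        dx = pt.as_tensor_variable(dx)
        return Apply(self, [x, dx], [x.type()])

    def perform(self, node, inputs, output_storage):
        x, dx = inputs
        _, tangent_out = jax.jvp(f, (x,), (dx,))
        output_storage[0][0] = np.asarray(tangent_out, dtype=node.outputs[0].dtype)


f_jvp_op = FJvpOp()


class FOp(Op):
    """Wraps f itself and provides forward-mode derivatives through R_op."""

    def make_node(self, x):
        x = pt.as_tensor_variable(x)
        return Apply(self, [x], [x.type()])

    def perform(self, node, inputs, output_storage):
        (x,) = inputs
        output_storage[0][0] = np.asarray(f(x), dtype=node.outputs[0].dtype)

    def R_op(self, inputs, eval_points):
        # Forward mode: push the tangent eval_points[0] through f
        (x,) = inputs
        (dx,) = eval_points
        if dx is None:
            return [None]
        return [f_jvp_op(x, dx)]


f_op = FOp()
```

You would still need a grad/L_op if you want the usual reverse-mode path; the R_op is what PyTensor’s Rop machinery calls for forward mode.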

Is it feasible to implement a convenience option for this in the NUTS sampler? I noticed numpyro has a runtime argument called “forward_mode_differentiation”: Markov Chain Monte Carlo (MCMC) — NumPyro documentation

I think for now I will try to learn numpyro for sampling JAX-based code, but I was just curious.

Nothing else should be needed for forward-mode diff in numpyro, as that’s done by JAX (not PyTensor) from the logp graph returned by PyMC. You just have to try passing that flag to the numpyro sampler.
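Something along these lines might work (untested; the module path and the assumption that nuts_kwargs is forwarded to numpyro.infer.NUTS may differ between PyMC versions):

```python
import numpy as np
import pymc as pm

# Recent PyMC versions keep the NumPyro backend here;
# older versions used pymc.sampling_jax instead.
from pymc.sampling.jax import sample_numpyro_nuts

with pm.Model():
    mu = pm.Normal("mu", 0.0, 1.0)
    sigma = pm.HalfNormal("sigma", 1.0)
    pm.Normal("y", mu, sigma, observed=np.array([0.3, -0.1, 0.8]))

    # Assumption: nuts_kwargs is passed on to numpyro.infer.NUTS, where
    # forward_mode_differentiation=True tells JAX to differentiate the
    # compiled logp with forward mode (JVP) instead of reverse mode.
    idata = sample_numpyro_nuts(
        draws=1000,
        tune=1000,
        nuts_kwargs={"forward_mode_differentiation": True},
    )
```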