Is PyMC capable of using forward-mode autodiff (JVP) from JAX?

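Yes, we have forward-mode autodiff: it's called R_op (L_op is the reverse-mode counterpart). R_op computes Jacobian-vector products, which is the same quantity JAX's jax.jvp returns, but you might have to investigate how it maps onto the JAX API in practice.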

General overview: Derivatives in PyTensor (PyTensor dev documentation)

There's some implementation info here: Creating a new Op: Python implementation (PyTensor dev documentation)