Would it be feasible to add a convenience option for this to the NUTS sampler? I noticed that NumPyro's NUTS kernel accepts an argument called "forward_mode_differentiation" (see the "Markov Chain Monte Carlo (MCMC)" page of the NumPyro documentation).
For now I'll probably learn NumPyro so I can sample my JAX-based code, but I'm curious.