I’m relatively new to PyMC3, having spent a bit of time using Stan for various Bayesian modeling tasks. When you really want to go “offroading” in Stan and build a semi-novel architecture, you have to account for Jacobian adjustments that the sampler doesn’t track automatically for HMC/NUTS to sample the intended density correctly. Computing gradients by hand isn’t impossible, but it’s an error-prone task.
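To make concrete what I mean by a manual Jacobian adjustment, here is a minimal PyMC3 sketch of the pattern as I understand it: sampling a scale parameter on the log scale and adding the log-Jacobian of the exp transform by hand via `pm.Potential` (the variable names and the HalfNormal prior are just placeholders I picked for illustration):

```python
import numpy as np
import pymc3 as pm
import theano.tensor as tt

data = np.random.normal(0.0, 1.5, size=50)  # toy data

with pm.Model() as model:
    # Sample sigma on the unconstrained log scale.
    log_sigma = pm.Flat("log_sigma")
    sigma = tt.exp(log_sigma)

    # Intended prior on sigma, plus the log-Jacobian of the exp transform:
    # log|d exp(x)/dx| = x, added by hand so NUTS targets the right density.
    pm.Potential("sigma_prior",
                 pm.HalfNormal.dist(sigma=1.0).logp(sigma) + log_sigma)

    y = pm.Normal("y", mu=0.0, sigma=sigma, observed=data)
    trace = pm.sample(1000, tune=1000)
```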
I’ve heard PyMC3 has something analogous, where you pass hand-computed gradient information over to Theano to ensure correct HMC sampling. (True?)
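For context, the closest thing I’ve found so far is wrapping a black-box log-likelihood in a custom Theano Op whose `grad` method returns a hand-coded gradient. A rough sketch of that pattern, where `my_loglike` and `my_grad` are hypothetical user-supplied functions:

```python
import numpy as np
import theano.tensor as tt


class LogLikeGrad(tt.Op):
    """Op returning the hand-coded gradient of the log-likelihood."""
    itypes = [tt.dvector]
    otypes = [tt.dvector]

    def __init__(self, grad_fn):
        self.grad_fn = grad_fn

    def perform(self, node, inputs, outputs):
        (theta,) = inputs
        outputs[0][0] = np.asarray(self.grad_fn(theta))


class LogLike(tt.Op):
    """Black-box log-likelihood with a gradient that NUTS can use."""
    itypes = [tt.dvector]  # parameter vector
    otypes = [tt.dscalar]  # scalar log-likelihood

    def __init__(self, loglike_fn, grad_fn):
        self.loglike_fn = loglike_fn
        self.loglike_grad = LogLikeGrad(grad_fn)

    def perform(self, node, inputs, outputs):
        (theta,) = inputs
        outputs[0][0] = np.array(self.loglike_fn(theta))

    def grad(self, inputs, output_grads):
        (theta,) = inputs
        # Chain rule: scale the hand-coded gradient by the upstream gradient.
        return [output_grads[0] * self.loglike_grad(theta)]
```

In a model you’d then hook it up with something like `pm.Potential("loglike", LogLike(my_loglike, my_grad)(theta))`.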
I’m most interested in the switch of PyMC3’s backend to JAX, where functions, loops, etc. can be differentiated automatically. It seems that one would never need to compute a gradient “by hand”; JAX would derive the gradient internally on its own, even if a model was novel to PyMC3.
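The kind of thing I have in mind is how `jax.grad` will happily differentiate an arbitrary Python log-density, loops and all, with no hand-derived formula (toy example, all names are mine):

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

def neg_log_posterior(theta, data):
    # theta = [mu, log_sigma]; the plain Python loop is unrolled during tracing.
    out = -jnp.sum(norm.logpdf(theta, 0.0, 10.0))  # weak prior on both params
    for x in data:
        out = out - norm.logpdf(x, theta[0], jnp.exp(theta[1]))
    return out

grad_fn = jax.grad(neg_log_posterior)  # gradient w.r.t. theta, derived by JAX
print(grad_fn(jnp.array([0.5, -1.0]), [1.2, 0.3, -0.7]))
```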
Am I overly optimistic about the JAX/PyMC3 future, or is this accurate?