PyMC3, Jax, and analytically computing gradients for novel model architectures

I’m relatively new to PyMC3, having spent a bit of time using Stan for various Bayesian modeling tasks. When you really want to go “offroading” and build a semi-novel architecture, you have to account for otherwise-untracked Jacobian adjustments yourself for HMC/NUTS to sample correctly. Computing gradients by hand isn’t impossible, but it’s an error-prone task.

I’ve heard PyMC3 has something analogous where you pass untracked gradient information over to Theano to ensure correct HMC sampling. (True?)
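Something like the “black-box likelihood” pattern, where you wrap an external log-likelihood in a Theano Op and supply its gradient by hand? A rough sketch of what I mean is below — `my_loglike` and `my_grad` are placeholders, and this may well not be the idiomatic way to do it:

```python
import numpy as np
import theano.tensor as tt


class LogLikeGrad(tt.Op):
    itypes = [tt.dvector]          # parameter vector theta
    otypes = [tt.dvector]          # d(loglike)/d(theta)

    def __init__(self, grad_loglike):
        self.grad_loglike = grad_loglike

    def perform(self, node, inputs, outputs):
        (theta,) = inputs
        outputs[0][0] = np.asarray(self.grad_loglike(theta))


class LogLikeWithGrad(tt.Op):
    itypes = [tt.dvector]          # parameter vector theta
    otypes = [tt.dscalar]          # scalar log-likelihood

    def __init__(self, loglike, grad_loglike):
        self.loglike = loglike
        self.loglike_grad = LogLikeGrad(grad_loglike)

    def perform(self, node, inputs, outputs):
        (theta,) = inputs
        outputs[0][0] = np.asarray(self.loglike(theta))

    def grad(self, inputs, output_grads):
        (theta,) = inputs
        # chain rule: upstream scalar gradient times our hand-coded gradient
        return [output_grads[0] * self.loglike_grad(theta)]


# usage sketch: pm.Potential("ll", LogLikeWithGrad(my_loglike, my_grad)(theta))
```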

I’m most interested in the switch of PyMC3’s backend to Jax, where functions, loops, etc. can be differentiated. It seems that one would never need to compute a gradient “by hand”; Jax would internally derive the gradient on its own, even if a model were novel to PyMC3.
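For example, this kind of thing (a toy example) is what makes me optimistic — ordinary Python control flow gets differentiated without any hand-derived math:

```python
import jax
import jax.numpy as jnp

def log_density(theta):
    # arbitrary function with a Python loop; Jax traces and differentiates it
    total = 0.0
    for k in range(3):
        total = total + jnp.sin(theta) ** (k + 1)
    return total

grad_fn = jax.grad(log_density)
print(grad_fn(0.5))
```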

Am I overly optimistic about the Jax/PyMC3 future, or is this accurate?

I am not sure I understand what you mean by “novel architecture” - gradient computation in PPLs like PyMC3 and Stan relies on an autodiff framework: for Stan this is their math library, for PyMC3 it is Theano. Other PPLs choose to build on top of other autodiff libraries like TF, Jax, or PyTorch, each with its pros and cons.
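As a toy illustration of what Theano does here — build the gradient graph symbolically and compile it, so nothing is ever computed “by hand”:

```python
import theano
import theano.tensor as tt

x = tt.dscalar("x")
logp = -0.5 * x ** 2 + tt.log1p(x ** 2)   # some arbitrary log-density-like expression
dlogp = tt.grad(logp, x)                  # Theano builds the gradient graph for you
f = theano.function([x], [logp, dlogp])
print(f(1.5))
```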
The Jacobian correction for a change of variables applies strictly to domain changes of random variables during sampling - in other words, when you apply a function to an RV explicitly, the volume changes are not accounted for.
You can build a PPL that accounts for that and builds volume-preserving function transformations. This will be available in PyMC3 once we switch to the RandomVariableTensor implementation in Theano (cc @brandonwillard). Currently, the only other framework that can do this is Oryx from TFP, which taps into the Jax execution trace and builds function transformations on top of it. The PyMC3 RandomVariableTensor, by contrast, will do this at the Theano level and compile to Jax or another backend (when it is fully implemented).
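To make the distinction concrete, a toy example (the pm.Potential line is illustrative - whether you need such a correction depends on what prior you actually intend to place on the transformed quantity):

```python
import pymc3 as pm
import theano.tensor as tt

with pm.Model() as m:
    # Domain change handled for you: HalfNormal is sampled on an unconstrained
    # (log-transformed) space and PyMC3 adds the Jacobian term automatically.
    sigma = pm.HalfNormal("sigma", 1.0)

    # Explicit function of an RV: `rate` is just a deterministic function of x,
    # and no volume correction is applied for it.
    x = pm.Normal("x", 0.0, 1.0)
    rate = pm.Deterministic("rate", tt.exp(x))

    # If you wanted to treat `rate` itself as the variable carrying a given prior,
    # you would add the log-Jacobian of y = exp(x), i.e. log|dy/dx| = x, yourself:
    # pm.Potential("jacobian", x)
```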

Likely we will do that in Theano, but you can already do it now (which is also how the experimental Jax sampler works).
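For example (the Jax sampler is experimental, so the exact function name and signature may change between releases):

```python
import pymc3 as pm
import pymc3.sampling_jax  # experimental module in recent PyMC3 releases

with pm.Model() as model:
    mu = pm.Normal("mu", 0.0, 1.0)
    pm.Normal("obs", mu, 1.0, observed=[0.1, -0.3, 0.7])

    # The Theano graph of the log-posterior is compiled to Jax and sampled
    # with NumPyro's NUTS.
    trace = pm.sampling_jax.sample_numpyro_nuts(draws=1000, tune=1000)
```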


@junpenglao thanks for the amplifying info! When you say:

when you apply a function to an RV explicitly, the volume changes are not accounted for.

Is this issue introduced any time pm.Deterministic is used?
I often need to add or subtract random variables in a model, IRT being an example (a toy sketch is below). I wasn’t aware that this would or should preclude the use of NUTS, but based on your feedback it seems like a more serious possibility. Thoughts?
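For concreteness, roughly this kind of thing (a toy Rasch-style sketch; the data and names are made up):

```python
import numpy as np
import pymc3 as pm

n_persons, n_items = 50, 10
responses = np.random.binomial(1, 0.5, size=(n_persons, n_items))  # fake data

with pm.Model() as irt:
    ability = pm.Normal("ability", 0.0, 1.0, shape=n_persons)
    difficulty = pm.Normal("difficulty", 0.0, 1.0, shape=n_items)
    # the linear predictor is a difference of random variables
    eta = pm.Deterministic("eta", ability[:, None] - difficulty[None, :])
    pm.Bernoulli("obs", logit_p=eta, observed=responses)
```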

As long as the random variables are continuous, NUTS will work fine, including cases where a function in the model maps the variables to a discrete space (well, HMC/NUTS still works, but for those dimensions the gradient is likely missing, so it will behave like some kind of random walk).
So no, Deterministic won’t preclude the use of HMC/NUTS.
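For instance, the derivative through a piecewise-constant mapping is zero (a toy example in Jax, but the Theano graph behaves the same way):

```python
import jax
import jax.numpy as jnp

def f(x):
    # continuous parameter pushed through a discrete-valued mapping, plus a smooth part
    return jnp.floor(x) + 0.1 * x

print(jax.grad(f)(2.7))  # 0.1 -- the floor term contributes no gradient anywhere
```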
