PyMC3, Jax, and analytically computing gradients for novel model architectures

I am not sure I understand what you mean by "novel architecture" - gradient computation in PPLs like PyMC3 and Stan relies on an autodiff framework: for Stan it is their math library, for PyMC3 it is Theano. Other PPLs choose to build on top of other autodiff libraries like TF, JAX, or PyTorch, each with its own pros and cons.
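For concreteness, here is a minimal sketch (in JAX, since it came up in the question) of the service the autodiff layer provides to a PPL: you write the log-density as ordinary code, and the framework hands the sampler its gradient. The model here is just a toy Normal example I made up for illustration.

```python
# Minimal sketch of what the autodiff layer does for a PPL: write the
# log-density, get its gradient for HMC/NUTS "for free".
import jax
import jax.numpy as jnp

def logp(theta, data):
    # Normal(0, 1) prior on theta plus a Normal(theta, 1) likelihood
    prior = -0.5 * theta ** 2
    lik = -0.5 * jnp.sum((data - theta) ** 2)
    return prior + lik

data = jnp.array([0.3, -0.1, 0.8])
grad_logp = jax.grad(logp)   # d logp / d theta, computed by autodiff
print(logp(0.5, data), grad_logp(0.5, data))
```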
The Jacobian correction for a change of variables applies strictly to domain changes of random variables during sampling - in other words, when you apply a function to an RV explicitly, the volume change is not accounted for.
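As a small illustration of what that correction is, consider a positive RV x sampled on the unconstrained scale y = log(x): the log-density of y needs the extra log|dx/dy| = y term, and simply applying the function without it gives the wrong density. A sketch, using a HalfNormal base density as an arbitrary example:

```python
# Change-of-variables correction for y = log(x), x ~ HalfNormal(1).
import jax.numpy as jnp
from jax.scipy import stats

def logp_x(x):
    # density of x on its natural (positive) domain
    return jnp.log(2.0) + stats.norm.logpdf(x, 0.0, 1.0)

def logp_y(y):
    # transform back AND add the Jacobian term: log|dx/dy| = y
    x = jnp.exp(y)
    return logp_x(x) + y

def logp_y_wrong(y):
    # naive "apply the function to the RV" version: volume change dropped
    return logp_x(jnp.exp(y))
```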
You can build a PPL that accounts for that and provides volume-preserving function transformations. This will be available in PyMC3 once we switch to the RandomVariableTensor implementation in Theano (cc @brandonwillard). Currently, the only other framework that can do this is Oryx from TFP, which taps into the JAX execution trace and builds function transformations on top of it. The PyMC3 RandomVariableTensor, by contrast, will do this at the Theano level and compile to JAX or other backends (once it is fully implemented).
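To make the idea concrete, here is a hypothetical sketch (this is not the PyMC3 or Oryx API) of what a "density-aware" transformation amounts to: pair the function applied to the RV with its Jacobian term, so the bookkeeping from the example above happens automatically instead of being dropped.

```python
# Hypothetical sketch, not a real PyMC3/Oryx API: transform a log-density
# through a bijection and use autodiff for the log|det J| term (scalar case).
import jax
import jax.numpy as jnp

def transform_logp(logp_x, inverse):
    """Log-density of y = f(x), where `inverse` maps y back to x."""
    def logp_y(y):
        x = inverse(y)
        jac = jax.grad(inverse)(y)          # dx/dy
        return logp_x(x) + jnp.log(jnp.abs(jac))
    return logp_y

# e.g. y = log(x) for x ~ Exponential(1): the inverse map is exp
logp_y = transform_logp(lambda x: -x, jnp.exp)
```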

Most likely we will do that in Theano, but you can already do it now (which is also how the experimental JAX sampler works).
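For reference, a sketch of how the experimental JAX sampler is invoked in PyMC3 3.11, where the Theano graph is compiled to JAX and sampled with NumPyro's NUTS; the module is experimental, so the exact path and signature may have changed:

```python
# Experimental JAX-based sampling in PyMC3 3.11 (API subject to change).
import pymc3 as pm
import pymc3.sampling_jax

with pm.Model() as model:
    mu = pm.Normal("mu", 0.0, 1.0)
    pm.Normal("obs", mu, 1.0, observed=[0.3, -0.1, 0.8])
    # Theano graph is compiled to JAX and sampled with NumPyro NUTS
    trace = pm.sampling_jax.sample_numpyro_nuts(draws=1000, tune=1000)
```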
