Auto-differentiation of user defined functions?

The HelloFresh engineering blog posted this tutorial. Unfortunately, the code is only available as a screenshot, so I can't paste a proper code block, but it still illustrates the point.

Of note, the authors define saturation_function, which is used three times in the computation of mu.

My questions are:

  1. Is HMC/NUTS able to determine the gradient of this function and sample effectively?
  2. (If 1 is true) how does PyMC3 treat this function? Is Theano able to "see" the function and determine the gradient automatically?
  3. (If 2 is false) will Jax be able to do this for PyMC3 in the future?


(1) No.
(2) Unsure, but it doesn't look like it right now.

I’m a little confused as to how they got this to work. Maybe this is just a typo from when they wrote the presentation, since I wouldn't expect them to directly show the .py files they actually use in practice.

My understanding is that PyMC3 won't be able to make use of the NumPy function directly. Also, while JAX can interface with NumPy, PyMC3's experimental JAX support operates at a different stage: it translates a Theano graph into a JAX one after PyMC3 and Theano have already finished structuring the model. I think the code below should work with their model.

import pymc3 as pm

def saturation_function(x_t, mu):
    # Built from pm.math ops, so Theano can differentiate it symbolically
    return (1 - pm.math.exp(-mu * x_t)) / (1 + pm.math.exp(-mu * x_t))
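As a side note (my own observation, not from the blog post): this expression is algebraically identical to tanh(mu * x_t / 2), which is easy to verify numerically with plain NumPy:

```python
import numpy as np

def saturation_np(x_t, mu):
    # Same formula as above, written with NumPy ops for a quick numeric check
    return (1 - np.exp(-mu * x_t)) / (1 + np.exp(-mu * x_t))

x = np.linspace(-3, 3, 101)
mu = 1.7  # arbitrary illustrative value
assert np.allclose(saturation_np(x, mu), np.tanh(mu * x / 2))
```

So, assuming I'm reading pm.math right, pm.math.tanh(mu * x_t / 2) should be an equivalent (and slightly shorter) way to write it.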

@ckrapu, in the changes you've made, you simply swapped the NumPy functions for their PyMC3 built-in equivalents. So it seems that PyMC3 is able to determine the gradient of user-defined functions, so long as they only call built-in Python operations and PyMC3 (i.e., Theano) ops. Pretty neat!

Sounds like the change to JAX will be largely invisible to users and is more oriented around development "housekeeping".

I wonder: is there a way to inspect what PyMC3 believes the gradient of a function to be (or whether one has even been computed)?
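Not a direct way to inspect the Theano graph, but one sanity check I use is comparing a hand-derived gradient against finite differences, which shows what the autodiff gradient should evaluate to. A plain-NumPy sketch (function names are my own, not from the post):

```python
import numpy as np

def saturation(x_t, mu):
    return (1 - np.exp(-mu * x_t)) / (1 + np.exp(-mu * x_t))

def saturation_grad_x(x_t, mu):
    # Hand-derived d/dx, using the identity saturation(x, mu) == tanh(mu * x / 2)
    return (mu / 2) * (1 - np.tanh(mu * x_t / 2) ** 2)

x, mu, h = 0.5, 2.0, 1e-6
# Central finite difference as an independent numeric estimate
numeric = (saturation(x + h, mu) - saturation(x - h, mu)) / (2 * h)
assert np.isclose(saturation_grad_x(x, mu), numeric)
```

For the symbolic side, I believe theano.grad and theano.printing.debugprint can be used to build and print the gradient graph directly, though I haven't dug into that here.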