Custom functions in JAX?

Sometimes a use case calls for a custom computation, and writing such functions in Theano is common. For example, here are the adstock and saturation functions from a 2017 Google paper on Media Mix Modeling:

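import theano.tensor as tt

def geometric_adstock(x, theta, alpha, L=12):
    # Carryover weights alpha**((l - theta)**2) for lags l = 0..L-1 (theta delays the peak)
    w = tt.as_tensor_variable([tt.power(alpha, tt.power(i - theta, 2)) for i in range(L)])
    # Row i is x shifted forward by i periods, zero-padded at the front
    xx = tt.stack([tt.concatenate([tt.zeros(i), x[:x.shape[0] - i]]) for i in range(L)])
    # Weighted sum of the lagged copies, with weights normalized to sum to 1
    return tt.dot(w / tt.sum(w), xx)

def saturation(x, s, k, b):
    # Hill-type saturation: k is the half-saturation point, s the slope, b the scale
    return b / (1 + (x / k) ** (-s))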
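For comparison, here is a minimal sketch of the same two functions ported to `jax.numpy`, assuming a straight line-by-line translation is what's wanted. It keeps the Theano version's weights and lag layout and, like the original, assumes `len(x) >= L`.

```python
import jax.numpy as jnp


def geometric_adstock(x, theta, alpha, L=12):
    """JAX port of the Theano adstock above (same weights, same lag layout)."""
    lags = jnp.arange(L)
    # Weights alpha**((l - theta)**2) for l = 0..L-1, normalized to sum to 1
    w = jnp.power(alpha, jnp.power(lags - theta, 2))
    w = w / jnp.sum(w)
    # Row l holds x shifted forward by l periods with l leading zeros
    # (assumes len(x) >= L, as the Theano version does)
    xx = jnp.stack([jnp.concatenate([jnp.zeros(l), x[: x.shape[0] - l]]) for l in range(L)])
    return jnp.dot(w, xx)


def saturation(x, s, k, b):
    """Hill-type saturation, unchanged from the Theano version."""
    return b / (1 + (x / k) ** (-s))
```

Because `L` is a plain Python integer, the list comprehension unrolls at trace time and every slice has a static shape, which is what `jax.jit` requires.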

These functions can then be used inside a PyMC3 model:

import arviz as az
import pymc3 as pm

# Assumes X is a pandas DataFrame of media spend (one column per channel) and y the observed response
with pm.Model() as m:
    #var,      dist, pm.name,          params,  shape
    alpha = pm.Beta('alpha',           3 , 3,   shape=X.shape[1]) # retain rate in adstock
    theta = pm.Uniform('theta',        0 , 12,  shape=X.shape[1]) # delay in adstock
    k     = pm.Beta('k',               2 , 2,   shape=X.shape[1]) # "ec" in saturation, half-saturation point
    s     = pm.Gamma('s',              3 , 1,   shape=X.shape[1]) # slope in saturation, Hill coefficient
    beta  = pm.Normal('beta',          0,  1,   shape=X.shape[1]) # regression coefficient
    tau   = pm.Normal('intercept',     0,  5                    ) # model intercept
    noise = pm.InverseGamma('noise',   0.05,  0.005             ) # noise scale (std. dev.) of y, used as sigma below

    # One adstock -> saturation term per channel
    computations = []
    for idx, col in enumerate(X.columns):
        comp = saturation(x=geometric_adstock(x=X[col].values,
                                              alpha=alpha[idx],
                                              theta=theta[idx],
                                              L=12),
                          b=beta[idx],
                          k=k[idx],
                          s=s[idx])
        computations.append(comp)

    y_hat = pm.Normal('y_hat', mu=tau + sum(computations),
                      sigma=noise,
                      observed=y)

    trace1 = pm.sample(chains=4)
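The per-channel Python loop in the model could also be expressed in JAX with `jax.vmap`, mapping one adstock-then-saturation function over the columns of a 2-D array. This is only a sketch of the mean `tau + sum(computations)`, not of the PyMC3 likelihood; it assumes `X` has been converted to an `(n_times, n_channels)` array, and `channel_effect` and `model_mean` are hypothetical helper names.

```python
import jax
import jax.numpy as jnp


def geometric_adstock(x, theta, alpha, L=12):
    # Same computation as the Theano geometric_adstock: normalized weights alpha**((l - theta)**2)
    w = jnp.power(alpha, jnp.power(jnp.arange(L) - theta, 2))
    xx = jnp.stack([jnp.concatenate([jnp.zeros(l), x[: x.shape[0] - l]]) for l in range(L)])
    return jnp.dot(w / jnp.sum(w), xx)


def saturation(x, s, k, b):
    return b / (1 + (x / k) ** (-s))


def channel_effect(x, alpha, theta, k, s, beta, L=12):
    # Hypothetical helper: one channel, adstock then saturation, as in the model's loop body
    return saturation(geometric_adstock(x, theta=theta, alpha=alpha, L=L), s=s, k=k, b=beta)


def model_mean(X, alpha, theta, k, s, beta, tau, L=12):
    # Hypothetical helper: X has shape (n_times, n_channels); each parameter has shape (n_channels,).
    # vmap maps over axis 1 of X (the channels) and axis 0 of each parameter.
    effects = jax.vmap(
        lambda x, a, t, kk, ss, b: channel_effect(x, a, t, kk, ss, b, L=L),
        in_axes=(1, 0, 0, 0, 0, 0),
    )(X, alpha, theta, k, s, beta)  # shape (n_channels, n_times)
    return tau + jnp.sum(effects, axis=0)
```

Mapping with `vmap` rather than a Python loop keeps the function a single traced computation, whatever the number of channels.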

In short: could the Theano functions defined in the first code block be replaced with functions written in JAX? I think so, but I'm not certain!

I think I read somewhere that this is possible, maybe with NumPy too?

Here is an example by @dfm: deterministic_op.ipynb · GitHub
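Whatever wrapper is used to call JAX code from a PyMC3 model, NUTS needs gradients of the wrapped function, not just its values. A minimal sketch of the JAX side, assuming a custom Theano `Op` would call one jitted function for its output and another for its gradient: `jax.vjp` gives the vector-Jacobian product with respect to the parameters. The names `channel_effect`, `forward`, and `vjp_params` are hypothetical, and the Theano `Op` class itself is not shown.

```python
import jax
import jax.numpy as jnp


def geometric_adstock(x, theta, alpha, L=12):
    # Same computation as the Theano geometric_adstock (assumes len(x) >= L)
    w = jnp.power(alpha, jnp.power(jnp.arange(L) - theta, 2))
    xx = jnp.stack([jnp.concatenate([jnp.zeros(l), x[: x.shape[0] - l]]) for l in range(L)])
    return jnp.dot(w / jnp.sum(w), xx)


def saturation(x, s, k, b):
    return b / (1 + (x / k) ** (-s))


def channel_effect(params, x):
    # params = (alpha, theta, k, s, beta) for a single channel; x is that channel's spend
    alpha, theta, k, s, beta = params
    return saturation(geometric_adstock(x, theta=theta, alpha=alpha), s=s, k=k, b=beta)


@jax.jit
def forward(params, x):
    # Output values a wrapping Op would return
    return channel_effect(params, x)


@jax.jit
def vjp_params(params, x, g):
    # Vector-Jacobian product w.r.t. params: sum_t g[t] * d out[t] / d params.
    # g plays the role of the upstream gradient a reverse-mode grad receives.
    _, pullback = jax.vjp(lambda p: channel_effect(p, x), params)
    (grad_params,) = pullback(g)
    return grad_params  # tuple of 5 gradients, same structure as params
```

The vector-Jacobian product is the natural quantity here because reverse-mode differentiation receives the gradient of the output and must return one gradient per input, without ever forming the full Jacobian.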
