# Custom functions in JAX?

Sometimes a model needs a custom computation, and in PyMC3 the usual way to write one is in Theano. For example, the geometric adstock function referenced in a 2017 Google paper on Media Mix Modeling:

```python
import theano.tensor as tt

def geometric_adstock(x, theta, alpha, L=12):
    # delayed-adstock weights: alpha ** (i - theta)^2 over lags i = 0..L-1
    w = tt.as_tensor_variable([tt.power(alpha, tt.power(i - theta, 2)) for i in range(L)])
    # lagged copies of x, each shifted right by i and zero-padded on the left
    xx = tt.stack([tt.concatenate([tt.zeros(i), x[:x.shape[0] - i]]) for i in range(L)])
    return tt.dot(w / tt.sum(w), xx)

def saturation(x, s, k, b):
    # Hill-type saturation curve
    return b / (1 + (x / k) ** (-s))
```

These functions can then be used inside a PyMC3 model:

```python
import arviz as az
import pymc3 as pm

# X is a DataFrame of media channels, y is the observed response
with pm.Model() as m:
    # var          dist             pm.name      params  shape
    alpha = pm.Beta('alpha',           3, 3,     shape=X.shape[1])  # retention rate in adstock
    theta = pm.Uniform('theta',        0, 12,    shape=X.shape[1])  # delay in adstock
    k     = pm.Beta('k',               2, 2,     shape=X.shape[1])  # "ec" in saturation, half-saturation point
    s     = pm.Gamma('s',              3, 1,     shape=X.shape[1])  # slope in saturation, Hill coefficient
    beta  = pm.Normal('beta',          0, 1,     shape=X.shape[1])  # regression coefficients
    tau   = pm.Normal('intercept',     0, 5)                        # model intercept
    noise = pm.InverseGamma('noise',   0.05, 0.005)                 # observation noise about y

    computations = []
    for idx, col in enumerate(X.columns):
        comp = saturation(x=geometric_adstock(x=X[col].values,
                                              alpha=alpha[idx],
                                              theta=theta[idx],
                                              L=12),
                          b=beta[idx],
                          k=k[idx],
                          s=s[idx])
        computations.append(comp)

    y_hat = pm.Normal('y_hat', mu=tau + sum(computations),
                      sigma=noise,
                      observed=y)

    trace1 = pm.sample(chains=4)
```

All that to say: could the Theano functions defined in the first code block be replaced with functions written in JAX? I think so… but I'm not certain!

I think I read somewhere that this is definitely possible — and with NumPy too, maybe?

Here is an example by @dfm: deterministic_op.ipynb · GitHub
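For what it's worth, the math itself seems to translate almost line-for-line to `jax.numpy` — here's a sketch of what I mean (same names and signatures as the Theano version above; I assume actually plugging this into a PyMC3 graph would still require wrapping it in a custom Theano Op, as in @dfm's notebook):

```python
import jax.numpy as jnp

def geometric_adstock(x, theta, alpha, L=12):
    # delayed-adstock weights: alpha ** (i - theta)^2 over lags i = 0..L-1
    w = jnp.array([alpha ** ((i - theta) ** 2) for i in range(L)])
    # lagged copies of x, each shifted right by i and zero-padded on the left
    xx = jnp.stack([jnp.concatenate([jnp.zeros(i), x[: x.shape[0] - i]])
                    for i in range(L)])
    # normalize the weights and sum the weighted lags
    return jnp.dot(w / jnp.sum(w), xx)

def saturation(x, s, k, b):
    # Hill-type saturation curve
    return b / (1 + (x / k) ** (-s))
```

This runs as plain JAX on NumPy-like arrays, so it could be jit-compiled and differentiated with `jax.grad` — the open question is just the glue back into PyMC3.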
