Sometimes a model needs a custom computation function, and writing such functions in Theano is common. For example, here are the adstock (carry-over) and saturation functions described in a 2017 Google paper on Media Mix Modeling:
import theano.tensor as tt
def geometric_adstock(x, theta, alpha, L=12):
    """Apply a normalized carry-over (adstock) transform to a 1-D series.

    Builds one weight per lag ``i = 0 .. L-1`` as
    ``w[i] = alpha ** ((i - theta) ** 2)``, normalizes the weights to sum
    to 1, and returns the weighted sum of ``x`` delayed by each lag. Earlier
    periods are padded with zeros.

    NOTE(review): despite the name, for 0 < alpha < 1 these weights are
    largest at the lag closest to ``theta`` (a "delayed" adstock) rather
    than decaying geometrically from lag 0. With ``theta == 0`` they decay
    as ``alpha ** (i ** 2)``.

    Args:
        x: Series to transform. Presumably a 1-D array/tensor indexed by
            time; confirm with callers.
        theta: Lag at which the carry-over weight peaks.
        alpha: Retention rate. Presumably in (0, 1) so the weights shrink
            away from ``theta``.
        L: Number of lags (0 .. L-1) to include. May exceed ``len(x)``.

    Returns:
        A Theano tensor with the same length as ``x``.
    """
    n = x.shape[0]
    # Weight for each lag; tt.power keeps the expression symbolic in
    # theta/alpha so it can be used inside a PyMC3 model.
    w = tt.as_tensor_variable(
        [tt.power(alpha, tt.power(i - theta, 2)) for i in range(L)]
    )
    # Row i is x delayed by i steps: prepend i zeros, then truncate back to
    # len(x). Truncating after concatenation (rather than slicing x[:n - i])
    # keeps every row exactly len(x) long even when i > len(x). In that case
    # a negative slice stop would wrap around and produce a mis-sized row.
    xx = tt.stack([tt.concatenate([tt.zeros(i), x])[:n] for i in range(L)])
    return tt.dot(w / tt.sum(w), xx)
def saturation(x, s, k, b):
    """Scale ``x`` through a Hill-type saturation curve with ceiling ``b``.

    Evaluates ``b / (1 + (x / k) ** (-s))``. At ``x == k`` the result is
    ``b / 2``, so ``k`` is the half-saturation point. ``s`` controls how
    sharply the curve rises, and for ``s > 0`` the output approaches ``b``
    as ``x`` grows.

    Works elementwise on scalars, NumPy arrays, or symbolic tensors, since
    only arithmetic operators are used.
    """
    relative = x / k
    denominator = 1 + relative ** (-s)
    return b / denominator
These functions can then be used inside a PyMC3 model:
import arviz as az
import pymc3 as pm
# Number of lags used by the adstock transform. The upper bound of the
# `theta` (delay) prior uses the same value, so the two settings cannot drift
# apart.
max_lag = 12

with pm.Model() as m:
    # NOTE(review): X (inputs, one column per channel) and y (observed
    # response) are defined outside this snippet. X is presumably a pandas
    # DataFrame, since `.columns` and `X[col].values` are used below.
    # Each per-channel prior has one entry per column of X.
    alpha = pm.Beta('alpha', 3, 3, shape=X.shape[1])  # retain rate in adstock
    theta = pm.Uniform('theta', 0, max_lag, shape=X.shape[1])  # delay in adstock
    # "ec" in saturation, the half-saturation point. A Beta prior restricts
    # it to (0, 1), so X is presumably pre-scaled to about [0, 1]; confirm.
    k = pm.Beta('k', 2, 2, shape=X.shape[1])
    s = pm.Gamma('s', 3, 1, shape=X.shape[1])  # slope in saturation, hill coefficient
    beta = pm.Normal('beta', 0, 1, shape=X.shape[1])  # regression coefficient
    tau = pm.Normal('intercept', 0, 5)  # model intercept
    # Passed to the likelihood below as `sigma`, i.e. the standard deviation
    # of y around mu (not its variance).
    noise = pm.InverseGamma('noise', 0.05, 0.005)

    # One adstock -> saturation contribution per channel; the contributions
    # are summed into the mean of the likelihood.
    computations = []
    for idx, col in enumerate(X.columns):
        comp = saturation(
            x=geometric_adstock(
                x=X[col].values,
                alpha=alpha[idx],
                theta=theta[idx],
                L=max_lag,
            ),
            b=beta[idx],
            k=k[idx],
            s=s[idx],
        )
        computations.append(comp)

    y_hat = pm.Normal(
        'y_hat',
        mu=tau + sum(computations),
        sigma=noise,
        observed=y,
    )

    trace1 = pm.sample(chains=4)
In short: could the Theano functions defined in the first code block be replaced with functions written in JAX? I think so, but I'm not certain!