Sometimes a use case needs a custom computation, and writing such functions in Theano is common. For example, here is the adstock function referenced in a 2017 Google paper on Media Mix Modeling, along with a saturation function:
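```python
import theano.tensor as tt


def geometric_adstock(x, theta, alpha, L=12):
    # Lag weights alpha ** ((l - theta) ** 2) for l = 0 .. L-1
    w = tt.as_tensor_variable([tt.power(alpha, tt.power(i - theta, 2)) for i in range(L)])
    # Row i holds x delayed by i steps, padded with i leading zeros
    xx = tt.stack([tt.concatenate([tt.zeros(i), x[:x.shape[0] - i]]) for i in range(L)])
    # Normalised weighted sum over the L lags
    return tt.dot(w / tt.sum(w), xx)


def saturation(x, s, k, b):
    # Hill-type saturation curve
    return b / (1 + (x / k) ** (-s))
```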
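Written out, the two functions above compute the following for a spend series $x_t$ (with $x_{t-l} = 0$ when $t < l$, which is what the zero padding does). This is read directly off the code; the paper's own notation may differ.

$$
\operatorname{adstock}(x; \alpha, \theta, L)_t = \frac{\sum_{l=0}^{L-1} \alpha^{(l-\theta)^2}\, x_{t-l}}{\sum_{l=0}^{L-1} \alpha^{(l-\theta)^2}}
$$

$$
\operatorname{saturation}(x; s, k, b) = \frac{b}{1 + (x/k)^{-s}} = \frac{b\, x^{s}}{x^{s} + k^{s}} \qquad (x > 0)
$$

So at $x = k$ the saturation output is $b/2$ (the half-saturation point), and $s$ sets the steepness of the curve.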
These functions can then be used in a PyMC3 model:
```python
import arviz as az
import pymc3 as pm

# Assumes X is a pandas DataFrame with one column per media channel
# and y is the observed target, both defined earlier
with pm.Model() as m:
    # var, dist, pm.name, params, shape
    alpha = pm.Beta('alpha', 3, 3, shape=X.shape[1])      # retain rate in adstock
    theta = pm.Uniform('theta', 0, 12, shape=X.shape[1])  # delay in adstock
    k = pm.Beta('k', 2, 2, shape=X.shape[1])              # "ec" in saturation, half saturation point
    s = pm.Gamma('s', 3, 1, shape=X.shape[1])             # slope in saturation, hill coefficient
    beta = pm.Normal('beta', 0, 1, shape=X.shape[1])      # regression coefficient
    tau = pm.Normal('intercept', 0, 5)                    # model intercept
    noise = pm.InverseGamma('noise', 0.05, 0.005)         # observation noise (passed as sigma, i.e. std. dev. of y)

    computations = []
    for idx, col in enumerate(X.columns):
        comp = saturation(x=geometric_adstock(x=X[col].values, alpha=alpha[idx], theta=theta[idx], L=12),
                          b=beta[idx], k=k[idx], s=s[idx])
        computations.append(comp)

    y_hat = pm.Normal('y_hat', mu=tau + sum(computations), sigma=noise, observed=y)
    trace1 = pm.sample(chains=4)
```
All that to say: could the Theano functions defined in the first code block be replaced with functions written in JAX? I think so, but I'm not certain!
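As a starting point, here is a minimal sketch, assuming JAX is installed, of what a JAX port of the two functions might look like. It mirrors the Theano code line for line and shows that the computation can be `jit`-compiled and differentiated with `jax.grad`, which is what a gradient-based sampler such as NUTS needs. It covers only the numerical functions: it does not show how to call them from inside a PyMC3 (Theano) model, which would need extra glue (for example a custom Theano Op that also supplies gradients). The helper names `channel_contribution`, `contribution_jit`, and `grad_alpha_theta`, and the parameter values used, are hypothetical.

```python
import jax
import jax.numpy as jnp


def geometric_adstock(x, theta, alpha, L=12):
    """JAX port of the Theano geometric_adstock (forward computation only)."""
    n = x.shape[0]
    lags = jnp.arange(L)
    # Weight for lag l: alpha ** ((l - theta) ** 2), normalised to sum to 1
    w = jnp.power(alpha, jnp.power(lags - theta, 2))
    w = w / jnp.sum(w)
    # Row l is x delayed by l steps (l leading zeros); L is a Python int,
    # so the list comprehension is unrolled when the function is traced
    xx = jnp.stack([jnp.concatenate([jnp.zeros(l), x[: n - l]]) for l in range(L)])
    return jnp.dot(w, xx)


def saturation(x, s, k, b):
    """Hill-type saturation, same formula as the Theano version."""
    return b / (1.0 + jnp.power(x / k, -s))


def channel_contribution(x, alpha, theta, k, s, b, L=12):
    # Hypothetical helper: adstock first, then saturation, as in the model loop
    return saturation(geometric_adstock(x, theta, alpha, L), s, k, b)


# jit-compile; L must be static because it determines array shapes
contribution_jit = jax.jit(channel_contribution, static_argnames=("L",))

# Gradient of the summed contribution w.r.t. alpha and theta
# (k=0.5, s=2.0, b=1.0 are arbitrary illustrative values)
grad_alpha_theta = jax.jit(
    jax.grad(
        lambda x, alpha, theta: jnp.sum(channel_contribution(x, alpha, theta, 0.5, 2.0, 1.0)),
        argnums=(1, 2),
    )
)

spend = jnp.linspace(0.1, 1.0, 30)  # hypothetical spend series, not real data
contrib = contribution_jit(spend, 0.5, 2.0, 0.5, 2.0, 1.0)
d_alpha, d_theta = grad_alpha_theta(spend, 0.5, 2.0)
```

A quick sanity check: feeding a unit impulse through `geometric_adstock` returns the normalised lag weights, peaking at lag `theta`, so the output sums to 1 when the series is at least `L` steps long.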