Hi team,
I am trying to fit the model below, but I am having trouble getting it to run.
- I have a dataset with 12 time series and 167 time points. I am trying to estimate the Lotka-Volterra parameters: X0 (shape = (12,)), r (shape = (12,)), and A (shape = (12, 12)). That is 12 + 12 + 144 = 168 parameters in total.
- I am creating PyTensor vectors and matrices for these, with Uniform and Normal priors.
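As a quick check on the parameter count, the shapes listed above can be tallied directly. This is only a small sketch; `n_species`, `shapes`, and `counts` are names introduced here for illustration.

```python
from math import prod

n_species = 12  # number of time series (hypothetical name, not from the model code)

# Shapes of the three parameter blocks described above
shapes = {
    "X0": (n_species,),
    "r": (n_species,),
    "A": (n_species, n_species),
}

# Number of scalar parameters in each block, and overall
counts = {name: prod(shape) for name, shape in shapes.items()}
total = sum(counts.values())
print(counts, total)  # 12 + 12 + 144 = 168
```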
import arviz as az
import pymc as pm
import pytensor.tensor as pt
from pytensor.compile.ops import as_op
from scipy.integrate import odeint

# dinner_sampled is a 167 x 13 time series dataset, including one column labeled "time"

# Differential equation
def gLV(X, t, r, AX):
    return X * (r + AX)

# Decorator declaring the input and output types as PyTensor double-precision matrices
@as_op(itypes=[pt.dmatrix], otypes=[pt.dmatrix])
def pytensor_forward_model_matrix(theta):
    X0 = theta[0, :]
    r = theta[1, :]
    AX = theta[2, :]
    return odeint(func=gLV, y0=X0,
                  t=dinner_sampled.time,
                  args=(r, AX,))
# theta_glv4 = results_glv4.x  # least squares solution used to inform the priors
n = dinner_sampled.shape[1] - 1

with pm.Model() as model:
    priors = []

    # Params for X0
    X0 = pm.Uniform("X0",
                    lower=0,
                    upper=120,
                    shape=n)
    priors.append(X0)

    # Params for r
    r = pm.Normal("r",
                  mu=0,
                  sigma=2,
                  shape=n)
    priors.append(r)

    # Params for A
    A = pm.Normal("A",
                  mu=0,
                  sigma=5,
                  shape=(n, n))
    AX = pm.Deterministic(name="AX", var=pm.math.dot(A, X0))
    priors.append(AX)

    # ODE solution function for Lotka-Volterra
    ode_solution = pytensor_forward_model_matrix(
        pm.math.stack(priors)
    )

    # Likelihood
    pm.Normal("Y_obs",
              mu=ode_solution,
              sigma=sigma,
              observed=dinner_sampled.iloc[:, 1:].values)
# Variable list to give to the sample step parameter
vars_list = list(model.values_to_rvs.keys())[:-1]

# Specify the sampler
sampler = "Slice Sampler"
tune = draws = 100

# Inference!
with model:
    trace_slice = pm.sample(step=[pm.Slice(vars_list)],
                            tune=tune,  # pass tune explicitly; otherwise pm.sample uses its default of 1000
                            draws=draws,
                            progressbar=True)
trace = trace_slice
az_trace = az.summary(trace)
az_trace
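To separate solver cost from sampler cost, a single forward solve can be timed on its own. The sketch below reuses the `gLV` right-hand side from the code above on synthetic inputs; the time grid and parameter values are assumptions chosen to keep the solution well behaved, not values from my dataset.

```python
import time

import numpy as np
from scipy.integrate import odeint


def gLV(X, t, r, AX):
    # Same right-hand side as in the model code above
    return X * (r + AX)


n_species = 12
t = np.linspace(0.0, 10.0, 167)   # assumed time grid with 167 points
X0 = np.ones(n_species)           # assumed initial abundances
r = np.full(n_species, 0.5)       # assumed growth rates
AX = np.full(n_species, -0.6)     # assumed interaction term (net rate of -0.1)

# Time one call to the ODE solver
start = time.perf_counter()
solution = odeint(func=gLV, y0=X0, t=t, args=(r, AX))
elapsed = time.perf_counter() - start

print(solution.shape)
print(f"one forward solve took {elapsed:.4f} s")
```

Because the forward model is wrapped with `as_op`, every likelihood evaluation calls `odeint` once, so multiplying this time by the number of likelihood evaluations per draw gives a rough idea of where the run time goes.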
Hypotheses for the slow run times:
- The “excess work done for this call” warning may mean the model is not specified correctly.
- If the model is specified correctly, would a different step method or sampler help?
- Or would I need more compute?
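On the specification question, one point that may be worth checking: in the code above, `AX` is computed once as `dot(A, X0)` and then held fixed during integration, whereas the textbook generalized Lotka-Volterra form, dX/dt = X * (r + A X), recomputes the interaction term from the current state. The sketch below contrasts the two right-hand sides using NumPy only. The function names, the small system size, and the random values are all hypothetical and not taken from the model.

```python
import numpy as np


def glv_fixed_interaction(X, t, r, AX):
    # Right-hand side as written in the post: AX is a constant vector
    return X * (r + AX)


def glv_state_dependent(X, t, r, A):
    # Textbook gLV form: interaction term recomputed from the current state X
    return X * (r + A @ X)


rng = np.random.default_rng(0)
n = 3                                    # small size, for illustration only
X0 = rng.uniform(1.0, 2.0, size=n)
r = rng.normal(0.0, 2.0, size=n)
A = rng.normal(0.0, 5.0, size=(n, n))
AX_fixed = A @ X0

# The two forms agree at the initial state ...
same_at_X0 = np.allclose(
    glv_fixed_interaction(X0, 0.0, r, AX_fixed),
    glv_state_dependent(X0, 0.0, r, A),
)

# ... but not once the state has moved away from X0
X1 = 2.0 * X0
same_at_X1 = np.allclose(
    glv_fixed_interaction(X1, 0.0, r, AX_fixed),
    glv_state_dependent(X1, 0.0, r, A),
)
print(same_at_X0, same_at_X1)

# One possible way to pass X0, r, and the full A through a single matrix input:
# stack X0 and r as rows on top of A, then slice them back out.
theta = np.vstack([X0, r, A])            # shape (n + 2, n)
X0_u, r_u, A_u = theta[0], theta[1], theta[2:]
```

If the state-dependent form is the intended one, the op would need the full `A` rather than `AX`; the `(n + 2, n)` packing at the end is one assumption about how that could fit a single `pt.dmatrix` input.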
Thanks, any help would be appreciated!