Generalized LV: missing 1 required positional argument using pymc.ode.DifferentialEquation

Hi team,

I am trying to train a generalized Lotka Volterra using pymc.

I have a 500 x 30 dataset where there are 500 data points and 30 time series.

# Run only in the beginning
!pip install --force-reinstall numpy==1.26.4

# Import libraries
import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pymc as pm
import pytensor
import pytensor.tensor as pt

from numba import njit
from pymc.ode import DifferentialEquation
from pytensor.compile.ops import as_op
from scipy.integrate import odeint, solve_ivp
from scipy.optimize import least_squares
from scipy.signal import detrend

import os

THEANO_FLAGS = "optimizer=fast_compile"

# Load priors 
# Shape (30, 1)
X0_din_lq = np.loadtxt(os.path.join(dinner_path, "X0_din_lq.csv"),
                       delimiter=",")

# Shape (30, 1)
r_din_lq = np.loadtxt(os.path.join(dinner_path, "r_din_lq.csv"),
                      delimiter=",")
# Shape (30, 30)
A_din_lq = np.loadtxt(os.path.join(dinner_path, "A_din_lq.csv"),
                      delimiter=",")

# Load time params for bayesian modeling
time = dinner_sampled['time'].values
t_span_din = (int(time[0]), int(time[-1]))
t_eval_din = np.linspace(*t_span_din, 500)

n = len(X0_din_lq)  # 30 time series
p = (2 * n) + (n ** 2) # number of features since we have X (30, 1), r (30, 1), A (30, 30)
X0 = dinner_wo_time.iloc[0, :].values

# dinner_wo_time is a pandas data frame with shape (500, 30) where 500 are the timepoints, 30 are the different time series. there is no "time" column hence the name.

# Solve ODE using the given parameters
def gLV(X, t, r, A):
  return X * (r + A @ X)

# Use PyMC's ODE solver
diffeq = DifferentialEquation(
    func=gLV,
    times=t_eval_din,
    n_states=n,  # Number of state variables
    n_theta=p,   # Number of parameters
    t0=t_span_din[0]
)

with pm.Model() as model:
    X0_prior = pm.Normal(name = "X0_prior",
                         mu = X0_din_lq,
                         sigma = 0.1,
                         initval = X0,
                         shape = X0.shape)

    r0_prior = pm.Normal(name = "r0_prior",
                         mu = r_din_lq,
                         sigma = 0.1,
                         initval = r_din_lq,
                         shape = r_din_lq.shape)

    A_prior = pm.Normal(name = "A_prior",
                        mu = A_din_lq,
                        sigma = 0.1,
                        initval = A_din_lq,
                        shape = A_din_lq.shape)

    ode_solution = diffeq(X0_prior, [r0_prior, A_prior])

    # Likelihood function
    sigma = pm.HalfNormal("sigma", sigma=2)
    pm.Normal("obs",
              mu=ode_solution,
              sigma=sigma,
              observed=dinner_wo_time.values)

Error:

/usr/local/lib/python3.11/dist-packages/pymc/ode/utils.py in augment_system(ode_func, n_states, n_theta)
    103 
    104     # Get symbolic representation of the ODEs by passing tensors for y, t and theta
--> 105     yhat = ode_func(t_y, t_t, t_p[n_states:])
    106     if isinstance(yhat, pt.TensorVariable):
    107         t_yhat = pt.atleast_1d(yhat)

TypeError: gLV() missing 1 required positional argument: 'A'

Would anyone know if this is just a misuse of the pymc.ode.DifferentialEquation class or is it because I am supplying a matrix prior? Thanks!