Fitting a hierarchical ODE with Sunode


I’m attempting to fit a hierarchical ODE model using sunode and I’m struggling with the best way to get all the dimensions to agree for sampling. In this toy example, I’m fitting measured drug amounts across multiple studies to a one-compartment pharmacokinetic model. However, the total number of measurements and the time when a measurement is taken will differ between the studies. To future-proof my code, I’m using pymc v.4 and aesara for this example.

Here is an example of some toy data that I’m trying to fit.

import pymc as pm
import numpy as np
import sunode
from sunode.wrappers.as_theano import solve_ivp
import matplotlib.pyplot as plt
import aesara.tensor as at
import aesara
import arviz as az  # needed for az.plot_posterior further down

# Model to fit
def one_cmpt(t, y, p):
    return {
        'A_expo': -p.k_abs*y.A_expo,
        'A_cent': p.k_abs*y.A_expo - p.k_elim*y.A_cent,
    }

# Create the data
SEED = 54321
DOSE = 1
rng = np.random.default_rng(SEED)  # seeded generator for the simulated measurement noise

problem = sunode.symode.SympyProblem(
    params={
        'k_abs': (),
        'k_elim': (),
    },
    states={
        'A_expo': (),
        'A_cent': (),
    },
    rhs_sympy=one_cmpt,
    derivative_params=[],  # assumption: no sensitivities needed to simulate the toy data
)

solver = sunode.solver.Solver(problem, solver='BDF')
tvals = np.arange(0, 26, 0.5)
y0 = np.zeros((), dtype=problem.state_dtype)
y0['A_expo'] = DOSE
y0['A_cent'] = 0

# True parameter values used to generate the toy data
solver.set_params_dict({
    'k_abs': 0.5,
    'k_elim': 0.2,
})

output = solver.make_output_buffers(tvals)
solver.solve(t0=0, tvals=tvals, y0=y0, y_out=output)
amt = output.view(problem.state_dtype)['A_cent']
ln_amt = np.log(amt)

# Create reported data points from each study
sigma = 0.3

tvals0 = np.array([1, 2, 4, 6, 12, 24]) # Times measured in hypothetical study 1
tvals1 = np.array([0.5, 2, 3, 6, 18]) # Times measured in hypothetical study 2
sidx0 = np.where(np.in1d(tvals, tvals0))
sidx1 = np.where(np.in1d(tvals, tvals1))

# Assume log normal distribution of data
ln_amt0 = rng.normal(ln_amt[sidx0], sigma)
ln_amt1 = rng.normal(ln_amt[sidx1], sigma)

# Plot the output
plt.plot(tvals, amt, label='True soln', color='blue')
plt.plot(tvals0, np.exp(ln_amt0), 'go', label='Study 1')
plt.plot(tvals1, np.exp(ln_amt1), 'ro', label='Study 2')
plt.xlabel('Time [hrs]')
plt.ylabel('Measured amount [mg]')
plt.legend()


Each point in this plot represents a reported time-course amount measured in each hypothetical study and I’d like to fit the ODE parameters (k_abs, k_elim) using a hierarchical model. However, study 1 took 6 measurements while study 2 took 5 measurements. To account for discrepancies in total measurements across studies, I attempt to flatten the predicted amounts (y_hat['A_cent']) from the ODE using the known indices (sidx0 and sidx1) for each study and at.concatenate. I then use at.reshape to transform from aesara tensor column vector to a row vector.

# Map the column number to the time index
idx_dict = dict()
idx_dict[0] = sidx0
idx_dict[1] = sidx1

n_datasets = len(idx_dict.keys())
# np.where returns a tuple, so sidx[0] holds the time indices of each study
sigma_idx = np.concatenate([[i]*len(sidx[0]) for i,sidx in idx_dict.items()], axis=None) # flattened np.array for study index
ln_amt_obs = np.concatenate([ln_amt0, ln_amt1], axis=None) # Flattened log-transformed amounts (observed)
with pm.Model() as model:
    # Population mean
    mu_lnk_abs = pm.Uniform('mu_lnk_abs', -5, 5)
    mu_lnk_elim = pm.Uniform('mu_lnk_elim', -5, 5)
    pm.Deterministic('mu_k_abs', pm.math.exp(mu_lnk_abs))
    pm.Deterministic('mu_k_elim', pm.math.exp(mu_lnk_elim))

    # Population sigma
    sigma_lnk_abs = pm.Exponential('sigma_lnk_abs', 1)
    sigma_lnk_elim = pm.Exponential('sigma_lnk_elim', 1)
    # Reparameterize for hierarchical sampling
    lnk_abs_offset = pm.Normal('lnk_abs_offset', mu=0, sigma=1, shape=(n_datasets,))
    lnk_abs = pm.Deterministic('lnk_abs', mu_lnk_abs + sigma_lnk_abs*lnk_abs_offset)
    lnk_elim_offset = pm.Normal('lnk_elim_offset', mu=0, sigma=1, shape=(n_datasets,))
    lnk_elim = pm.Deterministic('lnk_elim', mu_lnk_elim + sigma_lnk_elim*lnk_elim_offset)
    k_abs = pm.Deterministic('k_abs', pm.math.exp(lnk_abs))
    k_elim = pm.Deterministic('k_elim', pm.math.exp(lnk_elim))
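    y_hat, _, problem, solver, _, _ = solve_ivp(
        y0={
            'A_expo': np.array([1., 1.]), # Initial dose for each study is 1 mg.
            'A_cent': np.array([0., 0.]),
        },
        params={
            'k_abs': (k_abs, (n_datasets,)),
            'k_elim': (k_elim, (n_datasets,)),
            '_dummy': (np.array(1.), ()),
        },
        rhs=one_cmpt,
        tvals=tvals,   # assumed: same output grid as the simulated data
        t0=tvals[0],   # assumed: integration starts at the first grid point (t = 0)
    )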
    A_cent = y_hat['A_cent']
    pm.Deterministic('A_cent', A_cent)
    A_data = at.reshape(at.concatenate([A_cent[sidx,i] for i,sidx in idx_dict.items()], axis=1), (len(sigma_idx),)) # Index the appropriate times for each study and flatten to a row vector
    A_cent_mu = pm.Deterministic('A_cent_mu', A_data)
    lnA_cent_mu = pm.Deterministic('lnA_cent_mu', pm.math.log(A_cent_mu))
    sd = pm.HalfNormal('sd', shape=(n_datasets,))
    lnA_obs = pm.Normal('lnA_obs', mu=lnA_cent_mu, sigma=sd[sigma_idx], observed=ln_amt_obs)
    trace = pm.sample(tune=7000, draws=1000, target_accept=0.95)

This appears to do the trick and I’m able to predict the population parameters (mu_k_abs and mu_k_elim) pretty well.

with model:
    az.plot_posterior(trace, var_names=['mu_k_abs', 'mu_k_elim'], hdi_prob=0.9)

However, is there a better way to go about handling the Sunode output for this scenario? My questions are:

  1. Is flattening the output from Sunode and indexing the appropriate sd for each study the best method for handling observations across the different studies? If so, is there an optimal method for transforming the aesara tensor output from Sunode as opposed to the at.reshape method I used here?

  2. Otherwise, is there a better way to handle the Sunode output when dealing with studies with different sized observations to set up the hierarchical model?

Thanks a bunch for the help.

%watermark --iversions

matplotlib: 3.4.3
aesara    : 2.3.2
pymc      : 4.0.0b2
numpy     : 1.20.3
sunode    : 0.2.1
json      : 2.0.9

Hi @zult,
there are other ways to write the indexing, for example by stacking the results of an iterator that grabs one scalar at a time. Alternatively, with concatenate(..., axis=None), as you used on another line, I think you don’t need the reshape.
Of course you can benchmark these alternatives on some toy tensor, but I’d expect the difference to be marginal compared to the ODE integration.
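
As a rough illustration of the reshape-free indexing, here is a NumPy sketch on a toy array. The array A_cent_toy and its values are made up for the example, and it is only an assumption that the same slicing and axis-0 concatenation carries over unchanged to at.concatenate on the Aesara tensor. The key point is that each per-study slice is already 1-D, so joining them needs no reshape.

```python
import numpy as np

tvals = np.arange(0, 26, 0.5)
tvals0 = np.array([1, 2, 4, 6, 12, 24])   # study 1 times
tvals1 = np.array([0.5, 2, 3, 6, 18])     # study 2 times

# Plain integer index arrays (np.where returns a tuple, so take element 0)
sidx0 = np.where(np.isin(tvals, tvals0))[0]
sidx1 = np.where(np.isin(tvals, tvals1))[0]
idx_dict = {0: sidx0, 1: sidx1}

# Toy stand-in for y_hat['A_cent'] with shape (n_times, n_studies)
A_cent_toy = np.stack([tvals * 1.0, tvals * 10.0], axis=1)

# Each slice A_cent_toy[sidx, i] is 1-D, so concatenating gives a flat vector
A_data = np.concatenate([A_cent_toy[sidx, i] for i, sidx in idx_dict.items()])

# Study index of every flattened observation, e.g. for sd[study_idx]
study_idx = np.repeat(list(idx_dict.keys()), [len(s) for s in idx_dict.values()])
```

Here A_data has one entry per observation (11 in total) and study_idx plays the role of sigma_idx from the question.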

For your tvals, you could use the superset (union) of all measured timepoints instead of an np.arange grid. Whether that makes a speed or memory difference depends on your actual time data, of course.
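
A minimal sketch of that idea, using the two time vectors from the question: np.union1d builds the sorted superset and np.searchsorted recovers each study's positions in it. The name tvals_union is just illustrative.

```python
import numpy as np

tvals0 = np.array([1, 2, 4, 6, 12, 24])   # study 1 times
tvals1 = np.array([0.5, 2, 3, 6, 18])     # study 2 times

# Sorted union of all measured times: the solver only has to output these
tvals_union = np.union1d(tvals0, tvals1)

# Position of each study's times inside the union (exact matches by construction)
sidx0 = np.searchsorted(tvals_union, tvals0)
sidx1 = np.searchsorted(tvals_union, tvals1)
```

This reduces the output grid from 52 points to 9 for the toy data. Since the union no longer starts at 0, the initial time would presumably have to be passed explicitly (t0=0) rather than taken from the first output time.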

An alternative setup would be independent forward-passes for each study instead of a single combined/broadcasted forward pass.
That’s currently what we do in
It also has the classes for handling data with variable time vectors and a BaseODEModel where you only need to bring your dydt to the party.
We integrate with sunode (if installed) and I’ll fix the PyMC 4.0.0b2/Aesara compatibility tomorrow.
In “Bayesian calibration, process modeling and uncertainty quantification in biotechnology” (bioRxiv), we used it to build a hierarchical model with 28 replicates (that’d be 28 studies in your case) and rather sophisticated calibration models (observation functions) for the likelihood.
Code is here:

I would expect the independent forward passes to be faster in situations where the ODE solver’s variable step size slows it down at certain dynamics, thereby potentially slowing down all otherwise independent replicates too. Also, if there are a lot of replicates, the Jacobian could become very large (and sparse), making it costly.
Otherwise the solve_ivp approach you’re taking will probably be faster.
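
To make the "independent forward passes" idea concrete outside of any particular library, here is a sketch that uses SciPy's solve_ivp (not sunode's wrapper, and not murefi's API) to integrate the one-compartment model once per study, each on its own time vector. The per-study rate constants and the names one_cmpt_rhs, studies, and predictions are hypothetical.

```python
import numpy as np
from scipy.integrate import solve_ivp  # SciPy's solver, not sunode's wrapper


def one_cmpt_rhs(t, y, k_abs, k_elim):
    """One-compartment model with state y = [A_expo, A_cent]."""
    A_expo, A_cent = y
    return [-k_abs * A_expo, k_abs * A_expo - k_elim * A_cent]


# Hypothetical per-study settings: own time vector and own rate constants
studies = {
    "study_1": {"t": np.array([1, 2, 4, 6, 12, 24.0]), "k_abs": 0.5, "k_elim": 0.2},
    "study_2": {"t": np.array([0.5, 2, 3, 6, 18]), "k_abs": 0.6, "k_elim": 0.25},
}

predictions = {}
for name, s in studies.items():
    sol = solve_ivp(
        one_cmpt_rhs,
        t_span=(0.0, s["t"][-1]),
        y0=[1.0, 0.0],                    # 1 mg dose, nothing in the central compartment
        t_eval=s["t"],                    # report only this study's timepoints
        args=(s["k_abs"], s["k_elim"]),
        rtol=1e-8,
        atol=1e-10,
    )
    predictions[name] = sol.y[1]          # A_cent at the study's own times
```

Each study ends up with a prediction vector of exactly its own length, so no shared grid or index bookkeeping is needed. The price is one solver call per study, which is the trade-off described above.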

On a separate note, this single solve_ivp forward pass is definitely something we should add to murefi. (cc @lhelleckes)


Thank you so much for the response @michaelosthege. Quickly glancing through the paper, I think murefi is on the right track for what I’m after. The Monod ODE is going to be similar enough to compartmental pharmacokinetics so there’s a lot of overlap already. I’ll install murefi today and see if I can’t get it running for this example. Again, thanks for the insight.

Hi @zult, nice to hear!
We just released murefi v5.1.0 which is compatible with pymc3==3.11.4+Theano-PyMC+sunode or pymc==4.0.0b2+Aesara+sunode.
A similar compatibility fix was released with calibr8==6.3.0, so if you did a pip install before 6 PM UTC yesterday, you might want to update that one too.

Let me know how you like it, or if you encounter any problems!