SIR ODE model with Sunode

**Edit: I have found the problem. I wasn't passing the observed data to the likelihood (y). It would still be nice to get some advice on how to improve sampling, as the estimates are still slightly off. My apologies; I'll try not to make a habit of asking for help before finding a solution.**

Hi all. I'm trying to implement a simple SIR model in PyMC, based on this Stan article: Bayesian workflow for disease transmission modeling in Stan
At the moment I'm implementing it with Sunode (GitHub - pymc-devs/sunode: Solve ODEs fast, with support for PyMC).
However, my current attempt does not work well: the model does not converge (ESS values are very low) even with 2000 tuning steps/draws and target_accept = 0.95, and it gives estimates of beta, gamma, R0 and recovery time that are slightly off.
I must be doing something wrong, as sampling takes more than 10 minutes for just 14 time points (i.e. 14 days), especially since Sunode cannot sample on multiple cores on Windows. I'm also unsure whether the model is specified optimally, but I couldn't find any simple examples for PyMC v5 (rather than PyMC3). So any help or advice would be greatly appreciated. Many thanks in advance.
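For reference, this is the SIR system that the `SIR` function in the code encodes, with $N$ = `Ntot` = 763. Since $S + I + R = N$ stays constant, $N$ needs no equation of its own, and the two derived quantities reported in the summary are $R_0 = \beta/\gamma$ and the mean recovery time $1/\gamma$:

$$
\frac{dS}{dt} = -\frac{\beta S I}{N}, \qquad
\frac{dI}{dt} = \frac{\beta S I}{N} - \gamma I, \qquad
\frac{dR}{dt} = \gamma I, \qquad
R_0 = \frac{\beta}{\gamma}
$$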

```python
# -*- coding: utf-8 -*-
import numpy as np
import pymc as pm
import pandas as pd
import sunode
import sunode.wrappers.as_pytensor
import arviz as az

# Data from a 14-day influenza outbreak in a British boarding school with 763 male students
df = pd.read_csv("influenza_england_1978_school.csv")

infected = df.in_bed  # cases (students confined to bed)
n_days = len(infected)
Ntot = 763  # total number of students
# Output times 0, 1, ..., n_days - 1. (The Stan case study instead starts at t0 = 0
# and compares the data with the solution at t = 1, ..., n_days.)
times = np.arange(n_days)
t_0 = 0


# Right-hand side of the SIR ODEs for sunode: t = time, y = state variables, p = parameters.
# N = S + I + R (the whole population) stays constant, so it needs no equation of its own.
def SIR(t, y, p):
    return {
        'S': -(p.beta * y.S * y.I) / Ntot,                 # susceptible
        'I': (p.beta * y.S * y.I) / Ntot - p.gamma * y.I,  # infected
        'R': p.gamma * y.I,                                # recovered
    }


with pm.Model() as mod:
    beta = pm.TruncatedNormal("beta", 2, 1, lower=0)        # average number of (effective) contacts per person per day
    gamma = pm.TruncatedNormal("gamma", 0.4, 0.5, lower=0)  # recovery rate (inverse of the recovery time)

    y_hat, _, problem, solver, _, _ = sunode.wrappers.as_pytensor.solve_ivp(
        y0={
            'S': (Ntot - 1, ()),  # initial number of susceptibles
            'I': (1, ()),         # initial number of infected
            'R': (0, ()),         # initial number of recovered
        },
        params={
            'beta': (beta, ()),
            'gamma': (gamma, ()),
            'extra': np.zeros(1),  # constant (non-random) parameter, not used by SIR
        },
        rhs=SIR,
        tvals=times,
        t0=t_0,
    )

    succ_estim = pm.Deterministic('Se', y_hat['S'])
    infec_estim = pm.Deterministic('Ie', y_hat['I'])
    recov_estim = pm.Deterministic('Re', y_hat['R'])

    R0 = pm.Deterministic("R0", beta / gamma)
    recovery_time = pm.Deterministic("recovery time", 1 / gamma)

    # Overdispersion: PyMC's NegativeBinomial(mu, alpha) has variance mu + mu**2 / alpha
    phi_inv = pm.Exponential("phi_inv", 1)
    y = pm.NegativeBinomial("y", mu=infec_estim, alpha=phi_inv, observed=infected)


with mod:
    idata = pm.sample(tune=2000, draws=2000, chains=4, cores=1, target_accept=0.95)

summ = az.summary(idata, var_names=['beta', 'gamma', 'R0', 'recovery time'])
summ

# pos = idata.posterior.stack(sample=['chain', 'draw'])
```
```text
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Sequential sampling (4 chains in 1 job)
NUTS: [beta, gamma, phi_inv]
 |█████████████| 100.00% [4000/4000 01:57<00:00 Sampling chain 3, 0 divergences]
Sampling 4 chains for 2_000 tune and 2_000 draw iterations (8_000 + 8_000 draws total) took 418 seconds.

Out[1]: 
                mean     sd  hdi_3%  ...  ess_bulk  ess_tail  r_hat
beta           2.098  0.131   1.873  ...    4676.0    3123.0    1.0
gamma          0.510  0.061   0.391  ...    4667.0    4226.0    1.0
R0             4.187  0.662   3.117  ...    4308.0    3403.0    1.0
recovery time  1.990  0.245   1.562  ...    4667.0    4226.0    1.0

[4 rows x 9 columns]

[CVODEA ERROR]  CVodeF
  At t = 0.101077, mxstep steps taken before reaching tout.


[CVODEA ERROR]  CVodeF
  At t = 0.101077, mxstep steps taken before reaching tout.
```
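One thing worth checking when comparing with the Stan article is the negative binomial parameterization. PyMC's `NegativeBinomial(mu, alpha)` has variance `mu + mu**2 / alpha`, so `alpha` plays the role of `phi` in Stan's `neg_binomial_2(mu, phi)`. The Stan case study passes `phi = 1. / phi_inv`, whereas the model above passes `phi_inv` directly as `alpha`, so the parameter with the same name has the inverse meaning in the two models. The small sketch below (assuming SciPy is available; `nb_mu_alpha_to_scipy`, `mu` and `phi_inv` are hypothetical illustrative names and values) converts `(mu, alpha)` to SciPy's `(n, p)` form and shows how much the variance changes between the two choices:

```python
import numpy as np
from scipy import stats


def nb_mu_alpha_to_scipy(mu, alpha):
    """Convert a (mu, alpha) negative binomial, as in PyMC's NegativeBinomial
    (alpha plays the role of Stan's phi), to SciPy's (n, p) parameterization."""
    n = alpha
    p = alpha / (alpha + mu)
    return n, p


mu = 225.0     # hypothetical expected number of cases on one day
phi_inv = 0.2  # hypothetical value of phi_inv

for label, alpha in [("alpha = phi_inv", phi_inv), ("alpha = 1 / phi_inv", 1.0 / phi_inv)]:
    n, p = nb_mu_alpha_to_scipy(mu, alpha)
    mean, var = stats.nbinom.stats(n, p, moments="mv")
    print(f"{label}: mean = {float(mean):.1f}, variance = {float(var):.1f}, "
          f"mu + mu**2/alpha = {mu + mu**2 / alpha:.1f}")
```

To mirror the Stan model more closely, one would pass `alpha=1 / phi_inv` (and adjust the prior on `phi_inv` accordingly); whether that fully explains the difference in estimates is untested here.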

PS: I also don't quite understand the `[CVODEA ERROR]` messages printed by Sunode. Any info would be great; many thanks again.
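On the errors: sunode wraps SUNDIALS CVODES, and `CVodeF` is the forward integration used by its adjoint (CVODEA) module. The message says the solver took its maximum allowed number of internal steps (`mxstep`, set in CVODES with `CVodeSetMaxNumSteps`) without reaching the next requested output time, here getting stuck around t ≈ 0.1. One plausible cause, which I have not verified for this model, is that during tuning the sampler proposes extreme parameter values (for example a very large beta) that make the SIR system very fast and stiff, so the solver needs far more steps. The hypothetical sketch below illustrates the general effect with SciPy's explicit `RK45` solver (not CVODES, so the numbers say nothing about sunode's exact behaviour) by counting right-hand-side evaluations for a typical and an extreme beta:

```python
from scipy.integrate import solve_ivp

N = 763.0    # population size, as in the post
gamma = 0.5  # illustrative recovery rate (assumed)


def sir_rhs(t, y, beta):
    """SIR right-hand side for (S, I, R) with the infection rate passed as an argument."""
    S, I, R = y
    new_inf = beta * S * I / N
    return [-new_inf, new_inf - gamma * I, gamma * I]


y0 = [N - 1.0, 1.0, 0.0]
t_span = (0.0, 13.0)  # same window as times = np.arange(14)

# beta = 2 is near the prior mean; beta = 1000 stands in for an extreme proposal (assumed)
for beta in (2.0, 1000.0):
    sol = solve_ivp(sir_rhs, t_span, y0, method="RK45", args=(beta,))
    print(f"beta = {beta:6.1f}: success = {sol.success}, RHS evaluations = {sol.nfev}")
```

If extreme proposals are indeed the cause, tighter priors or a better initial point would reduce how often the solver wanders into that region.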


Digging around, I found this previous post: Bayesian Inference of an ODE SEIR model - #2 by jlindbloom. The repo recommended there (GitHub - Priesemann-Group/covid19_inference: Bayesian python toolbox for inference and forecast of the spread of the Coronavirus) has a very nice solution for SIR/SEIR models: instead of calling an ODE solver, it steps the compartments forward one day at a time in a `scan` loop. I implemented it for this small example, and it gives results closer to those of the Stan model.

```python
# -*- coding: utf-8 -*-
import numpy as np
import pymc as pm
import pytensor.tensor as pt
import pandas as pd
import arviz as az
from pytensor import scan

# Data from a 14-day influenza outbreak in a British boarding school with 763 male students
df = pd.read_csv("influenza_england_1978_school.csv")

infected = df.in_bed * 1.0  # cases, as floats
n_days = len(infected)
Ntot = 763  # total number of students
I_begin = infected[0]            # initial number of infected
S_begin = Ntot - infected[0]     # initial number of susceptibles
new_I0 = pt.zeros_like(I_begin)  # initial value for the "new infections" output


## This function is based on: https://github.com/Priesemann-Group/covid19_inference/blob/master/covid19_inference/model/compartmental_models/SIR.py
# One daily step of the discrete-time SIR update. scan passes the previous outputs
# (S_t, I_t, new_I_t) first, followed by the non_sequences (beta, gamma, N).
def sir_func(S_t, I_t, new_I_prev, beta, gamma, N):
    new_I_t = beta * S_t * I_t / N  # new infections during the day
    new_R_t = gamma * I_t           # new recoveries during the day
    S_t = S_t - new_I_t
    I_t = I_t + new_I_t - new_R_t
    I_t = pt.clip(I_t, -1, N)  # for stability
    S_t = pt.clip(S_t, 0, N)
    return S_t, I_t, new_I_t


with pm.Model() as mod:
    # Priors
    beta = pm.Wald('beta', 1, 1)
    gamma = pm.Wald('gamma', 1, 1)

    # Apply the daily update n_days times; element k of each output is the
    # state after k + 1 steps from the initial values
    outputs, _ = scan(
        fn=sir_func,
        outputs_info=[S_begin, I_begin, new_I0],
        non_sequences=[beta, gamma, Ntot],
        n_steps=n_days,
    )
    S_t, I_t, new_I_t = outputs

    # Likelihood (note: I_t[k] is k + 1 days after the initial value taken
    # from infected[0], yet it is compared with infected[k])
    phi_inv = pm.Exponential("phi_inv", 1)
    y = pm.NegativeBinomial("y", mu=I_t, alpha=phi_inv, observed=infected)

    R0 = pm.Deterministic("R0", beta / gamma)
    recovery_time = pm.Deterministic("recovery time", 1 / gamma)

with mod:
    idata = pm.sample(1000)
summ = az.summary(idata, hdi_prob=0.9)
```

```text
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [beta, gamma, phi_inv]
 |████████████| 100.00% [8000/8000 00:16<00:00 Sampling 4 chains, 0 divergences]
Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 36 seconds.

                mean     sd  hdi_5%  hdi_95%  ess_bulk  r_hat
beta           1.739  0.091   1.588    1.871    2077.0    1.0
gamma          0.483  0.052   0.398    0.565    2021.0    1.0
phi_inv        4.278  1.661   1.693    6.799    2417.0    1.0
R0             3.655  0.528   2.804    4.471    1769.0    1.0
recovery time  2.095  0.224   1.706    2.431    2021.0    1.0
```

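I used Wald (inverse Gaussian) priors instead of truncated Gaussians, as I find they sample a bit more smoothly. The model also samples much faster: 36 seconds here versus 418 seconds for the sunode version (though that run used 2000 tuning steps and draws per chain on a single core).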
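Note that the `sir_func` update is a forward-Euler step of the same SIR equations with a step size of one day. Early in the outbreak (S ≈ N) it multiplies I by roughly 1 + β − γ per day, whereas the ODE grows like e^{(β − γ)t}, so the same β and γ produce different epidemic curves and the two approaches need not give the same estimates. The sketch below, using assumed illustrative values β = 1.7, γ = 0.5 and one initial case (not fitted values), compares the daily update (without the clipping) against a tight-tolerance SciPy ODE solution; `sir_rhs` and `daily_euler` are hypothetical helper names:

```python
import numpy as np
from scipy.integrate import solve_ivp

N = 763.0               # population size, as in the post
beta, gamma = 1.7, 0.5  # illustrative values (assumed, not fitted)
S0, I0 = N - 1.0, 1.0   # hypothetical start: one infected student
n_days = 14


def sir_rhs(t, y):
    """Continuous-time SIR right-hand side for (S, I); R = N - S - I."""
    S, I = y
    new_inf = beta * S * I / N
    return [-new_inf, new_inf - gamma * I]


def daily_euler(S, I, n_steps):
    """Forward Euler with a one-day step, mirroring sir_func (clipping omitted)."""
    I_path = []
    for _ in range(n_steps):
        new_I = beta * S * I / N
        S, I = S - new_I, I + new_I - gamma * I
        I_path.append(I)
    return np.array(I_path)


days = np.arange(1, n_days + 1)
I_euler = daily_euler(S0, I0, n_days)  # I after 1, 2, ..., n_days daily steps
sol = solve_ivp(sir_rhs, (0.0, float(n_days)), [S0, I0], t_eval=days, rtol=1e-8, atol=1e-8)
I_ode = sol.y[1]                       # I(t) at t = 1, 2, ..., n_days

for d, i_e, i_o in zip(days, I_euler, I_ode):
    print(f"day {d:2d}: daily Euler I = {i_e:6.1f}   ODE I = {i_o:6.1f}")
```

With these values the ODE curve rises and peaks noticeably earlier than the daily update, so the discrete model would need different parameter values to fit the same data.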
