SIR ODE model with Sunode

Digging around I found this previous post: Bayesian Inference of an ODE SEIR model - #2 by jlindbloom . The repo recommended there (GitHub - Priesemann-Group/covid19_inference: Bayesian python toolbox for inference and forecast of the spread of the Coronavirus) has a very nice solution for SIR/SEIR models. I implemented it for this small example, and it gives results closer to those of the Stan model.

# -*- coding: utf-8 -*-
import numpy as np
import pymc as pm
import pytensor.tensor as pt
import pandas as pd
import arviz as az
from pytensor import scan

# Data from a 14-day influenza outbreak in a British boarding school with 763 male students
df = pd.read_csv("influenza_england_1978_school.csv")

infected = df.in_bed * 1.0  # cases
n_days = len(infected)
Ntot = 763 #total number of students
times = np.arange(n_days)
t_0 = 0
I_begin = infected[0]
S_begin = Ntot - infected[0]
new_I0 = pt.zeros_like(I_begin)

# This function is based on: https://github.com/Priesemann-Group/covid19_inference/blob/master/covid19_inference/model/compartmental_models/SIR.py
def sir_func(S_t, I_t, new_I0, beta, gamma, N):
    new_I_t = beta * S_t * I_t / N
    new_R_t = gamma * I_t 
    S_t = S_t - new_I_t 
    I_t = I_t + new_I_t - new_R_t   
    I_t = pt.clip(I_t, -1, N)  # for stability
    S_t = pt.clip(S_t, 0, N)
    return S_t, I_t, new_I_t
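As a quick sanity check, the same discrete-time update can be run forward in plain NumPy, outside the symbolic graph. The `beta` and `gamma` values below are only illustrative (roughly in the range the posterior ends up in), not part of the model:

```python
import numpy as np

def sir_step(S, I, beta, gamma, N):
    # Same discrete-time update as sir_func, with the same clipping
    new_I = beta * S * I / N
    new_R = gamma * I
    S_next = np.clip(S - new_I, 0, N)
    I_next = np.clip(I + new_I - new_R, -1, N)
    return S_next, I_next, new_I

N = 763
S, I = N - 1.0, 1.0  # one initial case
infected_traj = []
for _ in range(14):
    S, I, _ = sir_step(S, I, beta=1.7, gamma=0.5, N=N)
    infected_traj.append(I)
```

With these values the outbreak takes off, peaks, and declines within the 14 steps, which is the qualitative behavior the boarding-school data shows.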


with pm.Model() as mod:
    # Priors
    beta = pm.Wald('beta', 1, 1)
    gamma = pm.Wald('gamma', 1, 1)
    

    outputs, _ = scan(
        fn=sir_func,
        outputs_info=[S_begin, I_begin, new_I0],
        non_sequences=[beta, gamma, Ntot],
        n_steps=n_days
    )
    S_t, I_t, new_I_t = outputs
    
    # Likelihood
    phi_inv = pm.Exponential("phi_inv", 1)
    
    y = pm.NegativeBinomial("y", mu=I_t, alpha=phi_inv, observed=infected)
    
    R0 = pm.Deterministic("R0", beta/gamma)
    recovery_time = pm.Deterministic("recovery time", 1/gamma)

with mod:
    idata = pm.sample(1000)
summ = az.summary(idata, hdi_prob=0.9)

Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [beta, gamma, phi_inv]
 |████████████| 100.00% [8000/8000 00:16<00:00 Sampling 4 chains, 0 divergences]
Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 36 seconds.

                mean     sd  hdi_5%  hdi_95%  ess_bulk  r_hat
beta           1.739  0.091   1.588    1.871    2077.0    1.0
gamma          0.483  0.052   0.398    0.565    2021.0    1.0
phi_inv        4.278  1.661   1.693    6.799    2417.0    1.0
R0             3.655  0.528   2.804    4.471    1769.0    1.0
recovery time  2.095  0.224   1.706    2.431    2021.0    1.0
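One thing to keep in mind when reading the table: the `R0` row is the posterior mean of the ratio `beta/gamma`, which is not the same as the ratio of the posterior means, since E[beta/gamma] ≠ E[beta]/E[gamma] when the parameters are correlated. A quick check with the means from the table:

```python
beta_mean, gamma_mean = 1.739, 0.483   # posterior means from the table above

r0_from_means = beta_mean / gamma_mean   # ~3.60
recovery_from_means = 1 / gamma_mean     # ~2.07 days

# Both differ slightly from the table's R0 (3.655) and recovery time (2.095)
# rows, because those are posterior means of the transformed quantities.
```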

I used Wald (inverse Gaussian) priors instead of truncated Gaussians, which I find a bit smoother to sample; the model also runs much faster.
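For reference, the Wald distribution has strictly positive support with no hard boundary at zero, unlike a normal truncated at zero. Assuming SciPy is available, `pm.Wald(mu, lam)` corresponds to SciPy's `invgauss(mu/lam, scale=lam)`, so the prior can be inspected outside of PyMC:

```python
import numpy as np
from scipy import stats

mu, lam = 1.0, 1.0
# pm.Wald(mu, lam) matches scipy's invgauss(mu/lam, scale=lam)
prior = stats.invgauss(mu / lam, scale=lam)
draws = prior.rvs(size=10_000, random_state=np.random.default_rng(1))
# All draws are strictly positive, with mean approximately mu = 1
```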
