Digging around I found this previous post: Bayesian Inference of an ODE SEIR model - #2 by jlindbloom. The repo recommended there (GitHub - Priesemann-Group/covid19_inference: Bayesian python toolbox for inference and forecast of the spread of the Coronavirus) has a very nice solution for SIR/SEIR models. I implemented it for this small example, and it gives results close to those of the Stan model.
```python
# -*- coding: utf-8 -*-
import numpy as np
import pymc as pm
import pytensor.tensor as pt
import pandas as pd
import arviz as az
from pytensor import scan

# Data from a 14-day influenza outbreak in a British boarding school
# with 763 male students
df = pd.read_csv("influenza_england_1978_school.csv")
infected = df.in_bed * 1.0  # cases
n_days = len(infected)
Ntot = 763  # total number of students
times = np.arange(n_days)
t_0 = 0

I_begin = infected[0]
S_begin = Ntot - infected[0]
new_I0 = pt.zeros_like(I_begin)
```
```python
# This function is based on:
# https://github.com/Priesemann-Group/covid19_inference/blob/master/covid19_inference/model/compartmental_models/SIR.py
def sir_func(S_t, I_t, new_I0, beta, gamma, N):
    new_I_t = beta * S_t * I_t / N
    new_R_t = gamma * I_t
    S_t = S_t - new_I_t
    I_t = I_t + new_I_t - new_R_t
    I_t = pt.clip(I_t, -1, N)  # for stability
    S_t = pt.clip(S_t, 0, N)
    return S_t, I_t, new_I_t
```
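To see what `scan` will iterate, here is the same discrete-time SIR update as a plain NumPy sketch, forward-simulated for 14 days. The parameter values `beta=1.7, gamma=0.5` are illustrative assumptions (roughly the posterior means reported further down), not part of the original code:

```python
import numpy as np

def sir_step(S, I, beta, gamma, N):
    # One discrete-time SIR update, mirroring sir_func above (without the clipping)
    new_I = beta * S * I / N  # new infections this step
    new_R = gamma * I         # new recoveries this step
    S = S - new_I
    I = I + new_I - new_R
    return S, I, new_I

# Forward-simulate 14 days starting from one infected student
N = 763
S, I = N - 1.0, 1.0
traj = []
for _ in range(14):
    S, I, _new = sir_step(S, I, beta=1.7, gamma=0.5, N=N)
    traj.append(I)
```

With these values the trajectory rises, peaks, and declines within the two weeks, which is the qualitative shape of the boarding-school data.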
```python
with pm.Model() as mod:
    # Priors
    beta = pm.Wald("beta", 1, 1)
    gamma = pm.Wald("gamma", 1, 1)

    # Placeholder state vectors (not used by scan below,
    # which starts from S_begin / I_begin directly)
    S = pt.zeros(n_days)
    I = pt.zeros(n_days)
    R = pt.zeros(n_days)
    S = pt.inc_subtensor(S[0], Ntot - infected[0])
    I = pt.inc_subtensor(I[0], infected[0])

    # Iterate the discrete SIR update over n_days
    outputs, _ = scan(
        fn=sir_func,
        outputs_info=[S_begin, I_begin, new_I0],
        non_sequences=[beta, gamma, Ntot],
        n_steps=n_days,
    )
    S_t, I_t, new_I_t = outputs

    # Likelihood
    phi_inv = pm.Exponential("phi_inv", 1)
    y = pm.NegativeBinomial("y", mu=I_t, alpha=phi_inv, observed=infected)
    R0 = pm.Deterministic("R0", beta / gamma)
    recovery_time = pm.Deterministic("recovery time", 1 / gamma)

with mod:
    idata = pm.sample(1000)

summ = az.summary(idata, hdi_prob=0.9)
```
```
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [beta, gamma, phi_inv]
 |████████████| 100.00% [8000/8000 00:16<00:00 Sampling 4 chains, 0 divergences]
Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 36 seconds.
```
```
                mean     sd  hdi_5%  hdi_95%  ess_bulk  r_hat
beta           1.739  0.091   1.588    1.871    2077.0    1.0
gamma          0.483  0.052   0.398    0.565    2021.0    1.0
phi_inv        4.278  1.661   1.693    6.799    2417.0    1.0
R0             3.655  0.528   2.804    4.471    1769.0    1.0
recovery time  2.095  0.224   1.706    2.431    2021.0    1.0
```
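One detail worth keeping in mind when reading `phi_inv`: PyMC's `NegativeBinomial(mu, alpha)` uses `alpha` as the overdispersion parameter, with `Var[y] = mu + mu**2 / alpha`, and the model plugs `phi_inv` in as `alpha` directly. As an illustration of how that parameterization maps onto the classical `(n, p)` form (using `scipy.stats` here purely for the sketch; `mu=100, alpha=4.3` are made-up values chosen near the posterior mean above):

```python
from scipy import stats

# PyMC's NegativeBinomial(mu, alpha) has Var[y] = mu + mu**2 / alpha.
# It maps onto scipy's classical nbinom(n, p) via n = alpha, p = alpha / (alpha + mu).
mu, alpha = 100.0, 4.3
n, p = alpha, alpha / (alpha + mu)

dist = stats.nbinom(n, p)
mean = dist.mean()  # equals mu
var = dist.var()    # equals mu + mu**2 / alpha
```

A small `alpha` therefore means heavy overdispersion relative to a Poisson with the same mean.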
I used Wald (inverse Gaussian) distributions instead of truncated Gaussians for the rate priors, which I find a bit smoother to sample; the model also runs much faster.