Hi everyone,
I have been working on Bayesian inference models using SEIRS framework to analyze COVID-19 and influenza co-infection data. However, I have encountered a major problem. I generated some data using the seirs model and put it into this code to run, but the results were not satisfactory. After running MCMC sampling, I found that the posterior samples are very different from the actual observed data, by several orders of magnitude.
Here is my code :
import csv
import matplotlib.pyplot as plt
from datetime import datetime
import numpy as np
import pymc as pm
import arviz as az
import sunode
import sunode.wrappers.as_pytensor
from scipy.interpolate import interp1d
from datetime import datetime, timedelta
class DiseaseDataVisualizer:
def __init__(self, covid_file_path, influenza_file_path):
self.covid_file_path = covid_file_path
self.influenza_file_path = influenza_file_path
self.covid_data = self.read_covid_data(self.covid_file_path)
self.influenza_data = self.read_influenza_data(self.influenza_file_path)
def read_covid_data(self, file_path):
disease_data = {}
with open(file_path, 'r') as file:
reader = csv.DictReader(file)
for row in reader:
region = row['WHO_region']
date_reported_str = row['Date_reported']
try:
date_reported = datetime.strptime(date_reported_str, '%Y/%m/%d')
except:
date_reported = datetime.strptime(date_reported_str, '%Y-%m-%dT%H:%M:%S.%fZ')
new_cases = int(row['New_cases']) if row['New_cases'] else 0
if region not in disease_data:
disease_data[region] = {'dates': [], 'new_cases': []}
if date_reported in disease_data[region]['dates']:
index = disease_data[region]['dates'].index(date_reported)
disease_data[region]['new_cases'][index] += new_cases
else:
disease_data[region]['dates'].append(date_reported)
disease_data[region]['new_cases'].append(new_cases)
return disease_data
def read_influenza_data(self, file_path):
disease_data = {}
with open(file_path, 'r') as file:
reader = csv.DictReader(file)
for row in reader:
region = row['WHO region'] #+ 'O'
try:
date_reported_str = row['ISO_SDATE']
except:
date_reported_str = row['ISO_WEEKSTARTDATE']
try:
date_reported = datetime.strptime(date_reported_str, '%Y/%m/%d')
except:
date_reported = datetime.strptime(date_reported_str, '%Y-%m-%dT%H:%M:%S.%fZ')
new_cases = int(row['INF_ALL']) if row['INF_ALL'] else 0
if region not in disease_data:
disease_data[region] = {'dates': [], 'new_cases': []}
if date_reported in disease_data[region]['dates']:
index = disease_data[region]['dates'].index(date_reported)
disease_data[region]['new_cases'][index] += new_cases
else:
disease_data[region]['dates'].append(date_reported)
disease_data[region]['new_cases'].append(new_cases)
return disease_data
def interpolate_data(self, dates, cases):
date_range = (dates[-1] - dates[0]).days
all_dates = [dates[0] + timedelta(days=i) for i in range(date_range + 1)]
# Convert dates to ordinal (numbers) for interpolation
ordinal_dates = [date.toordinal() for date in dates]
all_ordinal_dates = [date.toordinal() for date in all_dates]
# Perform linear interpolation
interp_func = interp1d(ordinal_dates, cases, kind='linear', fill_value="extrapolate")
interpolated_cases = interp_func(all_ordinal_dates)
return all_dates, interpolated_cases
def get_region_data(self, region):
covid_dates = self.covid_data[region]['dates']
covid_cases = self.covid_data[region]['new_cases']
influenza_dates = self.influenza_data[region]['dates']
influenza_cases = self.influenza_data[region]['new_cases']
interpolated_covid_dates, interpolated_covid_cases = self.interpolate_data(covid_dates, covid_cases)
interpolated_influenza_dates, interpolated_influenza_cases = self.interpolate_data(influenza_dates, influenza_cases)
data = {
'covid_dates': interpolated_covid_dates,
'covid_cases': interpolated_covid_cases,
'influenza_dates': interpolated_influenza_dates,
'influenza_cases': interpolated_influenza_cases
}
return data
class SERIS_model():
def __init__(self,region,data) -> None:
self.covid_dates = data['covid_dates']
self.covid_cases = data['covid_cases']
self.influenza_dates = data['influenza_dates']
self.influenza_cases = data['influenza_cases']
self.parameter = {
'sigma_1' : 0.15,
'sigma_2' : 0.5,
'gamma_1' : 1/7,
'gamma_2' : 0.2,
'gamma_3' : 0.1,
'theta1' : 0.001,
'theta2' : 0.001,
'xi' : 1/365 ,
}
self.samples_params = {
'n_samples' : 100,
'n_tune' : 20,
'cores' : 12,
}
self.region_population = {
'EURO' : 744000000,
'AMRO' : 1018000000,
'AFRO' : 1216000000,
'SEARO' : 1984000000,
'WPRO' : 1650000000,
'EMRO' : 654000000,
}
self.region = region
self.times = np.arange(0,max(len(self.covid_dates), len(self.influenza_dates)),1)
self.init_y = [
self.region_population[self.region] - self.covid_cases[0] / self.parameter['sigma_1'] - self.influenza_cases[0] / self.parameter['sigma_2'] - 0.38*self.region_population[self.region], #S0
self.covid_cases[0] / self.parameter['sigma_1'],
self.influenza_cases[0] / self.parameter['sigma_2'],
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0.38*self.region_population[self.region]
]
def SEIRS_sunode(self, t, y, p):
return {
'S' : -p.beta1 * (y.S / p.n) * (y.I1 + y.I0) - p.beta2 * (y.S / p.n) * (y.I2 + y.I0) + p.xi * y.R,
'E1' : p.beta1 * (y.S / p.n) * (y.I1 + y.I0) - p.sigma1 * y.E1 - p.theta1 * y.E1,
'E2' : p.beta2 * (y.S / p.n) * (y.I2 + y.I0) - p.sigma2 * y.E2 - p.theta2 * y.E2,
'E120' : p.theta2 * y.E1 - p.sigma1 * y.E120,
'E210' : p.theta1 * y.E2 - p.sigma2 * y.E210,
'I0' : p.sigma1 * y.E120 + p.sigma2 * y.E210 - p.gamma3 * y.I0,
'I1' : p.sigma1 * y.E1 - p.gamma1 * y.I1,
'I2' : p.sigma2 * y.E2 - p.gamma2 * y.I2,
'S12' : p.gamma1 * y.I1 - p.gamma3 * (y.I12 + y.I2 + y.I0) * (y.S12 / p.n),
'S21' : p.gamma2 * y.I2 - p.gamma3 * (y.I21 + y.I1 + y.I0) * (y.S21 / p.n),
'E12' : p.delta1 * (y.I12 + y.I2 + y.I0) * (y.S12 / p.n) - p.sigma1 * y.E12,
'E21' : p.delta2 * (y.I21 + y.I1 + y.I0) * (y.S21 / p.n) - p.sigma2 * y.E21,
'I12' : p.sigma1 * y.E12 - p.gamma2 * y.I12,
'I21' : p.sigma2 * y.E21 - p.gamma1 * y.I21,
'R' : p.gamma2 * y.I12 + p.gamma1 * y.I21 - p.xi * y.R + p.gamma3 * y.I0
}
def run_SEIRS_model(self):
with pm.Model() as self.model:
# Priors for unknown model parameters
beta1 = pm.LogNormal('beta1', np.log(0.5), 1)
beta2 = pm.LogNormal('beta2', np.log(0.2), 1)
delta1 = pm.LogNormal('delta1', np.log(0.5), 1)
delta2 = pm.LogNormal('delta2', np.log(0.2), 1)
sigma = pm.HalfNormal('sigma', 1, shape=1)
self.report_rate = 1
res, _, problem, solver, _, _ = sunode.wrappers.as_pytensor.solve_ivp(
y0={
# The initial conditions of the ode. Each variable
# needs to specify a theano or numpy variable and a shape.
# This dict can be nested.
'S': (self.init_y[0], ()),
'E1': (self.init_y[1], ()),
'E2': (self.init_y[2], ()),
'E120': (self.init_y[3], ()),
'E210': (self.init_y[4], ()),
'I0': (self.init_y[5], ()),
'I1': (self.init_y[6], ()),
'I2': (self.init_y[7], ()),
'S12': (self.init_y[8], ()),
'S21': (self.init_y[9], ()),
'E12': (self.init_y[10], ()),
'E21': (self.init_y[11], ()),
'I12': (self.init_y[12], ()),
'I21': (self.init_y[13], ()),
'R': (self.init_y[14], ())},
params={
# Each parameter of the ode. sunode will only compute derivatives
# with respect to theano variables. The shape needs to be specified
# as well. It it infered automatically for numpy variables.
# This dict can be nested.
'beta1': (beta1, ()),
'beta2': (beta2, ()),
'delta1': (delta1, ()),
'delta2': (delta2, ()),
'theta1': (self.parameter['theta1'], ()),
'theta2': (self.parameter['theta2'], ()),
'sigma1': (self.parameter['sigma_1'], ()),
'sigma2': (self.parameter['sigma_2'], ()),
'gamma1': (self.parameter['gamma_1'], ()),
'gamma2': (self.parameter['gamma_2'], ()),
'gamma3': (self.parameter['gamma_3'], ()),
'xi': (self.parameter['xi'], ()),
'n': (self.region_population[self.region], ()),
'_dummy': (np.array(1.), ()),
},
# A functions that computes the right-hand-side of the ode using
# sympy variables.
rhs=self.SEIRS_sunode,
# The time points where we want to access the solution
tvals=self.times,
t0=self.times[0]
)
I1_all = self.parameter['sigma_1']*(res['E1'] + res['E21'] + res['E210'])
I2_all = self.parameter['sigma_2']*(res['E2'] + res['E12'] + res['E120'])
self.covid_relative_cases = self.report_rate * self.covid_cases
self.influenza_relative_cases = self.report_rate * self.influenza_cases
covid_obs = pm.StudentT('covid_obs',nu=10, mu=I1_all, sigma=sigma, observed=self.covid_relative_cases)
influenza_obs = pm.StudentT('influenza_obs',nu=10, mu=I2_all, sigma=sigma, observed=self.influenza_relative_cases)
# Run the MCMC
trace = pm.sample(self.samples_params['n_samples'], tune=self.samples_params['n_tune'], cores=self.samples_params['cores'])
trace.to_netcdf('./StudentT_trace_test.nc')
return trace
def analysis(self, trace):
az.plot_trace(trace)
print(az.summary(trace))
az.plot_forest(trace, r_hat=True)
az.plot_posterior(trace)
with self.model:
ppc_samples = pm.sample_posterior_predictive(trace,extend_inferencedata=True)
covid_obs_check = ppc_samples.posterior_predictive['covid_obs']
influenza_obs_check = ppc_samples.posterior_predictive['influenza_obs']
covid_obs_mean = covid_obs_check.mean(axis=0)
covid_obs_CriL = np.percentile(covid_obs_check, q=2.5, axis=0)
covid_obs_CriU = np.percentile(covid_obs_check, q=97.5, axis=0)
influenza_obs_mean = influenza_obs_check.mean(axis=0)
influenza_obs_CriL = np.percentile(influenza_obs_check, q=2.5, axis=0)
influenza_obs_CriU = np.percentile(influenza_obs_check, q=97.5, axis=0)
plt.figure(figsize=(15, 2 * (5)))
plt.subplot(2, 1, 1)
plt.plot(self.covid_relative_cases, "o", color="r", lw=1, ms=10.5, label="Observed")
for i in range(covid_obs_mean.shape[0]):
plt.plot(self.times, covid_obs_mean[i, :], color="b", lw=0.3)
plt.legend(fontsize=15)
plt.xlabel("Days", fontsize=15)
plt.ylabel("Covid", fontsize=15)
plt.subplot(2, 1, 2)
plt.plot(self.influenza_relative_cases, "o", color="b", lw=1, ms=10.5, label="Observed")
for i in range(influenza_obs_mean.shape[0]):
plt.plot(self.times, influenza_obs_mean[i, :], color="r", lw=0.3)
plt.legend(fontsize=15)
plt.xlabel("Days", fontsize=15)
plt.ylabel("Influenza", fontsize=15)
plt.show(block=True)
print('done')
if __name__ == '__main__':
region = 'EURO'
data_visualizer = DiseaseDataVisualizer('COVID_Test.csv', 'ILI_test.csv')
data = data_visualizer.get_region_data(region)
model = SERIS_model(region,data)
trace = model.run_SEIRS_model()
model.analysis(trace)
My Test data :
COVID_Test.csv
WHO_region,Date_reported,New_cases
EURO,2024/8/1,0
EURO,2024/8/2,252316
EURO,2024/8/3,456938
EURO,2024/8/4,661231
EURO,2024/8/5,896038
EURO,2024/8/6,1183257
EURO,2024/8/7,1538148
EURO,2024/8/8,1967152
EURO,2024/8/9,2460277
EURO,2024/8/10,2983915
EURO,2024/8/11,3473403
EURO,2024/8/12,3846711
EURO,2024/8/13,4029418
EURO,2024/8/14,3996789
EURO,2024/8/15,3779914
EURO,2024/8/16,3455787
EURO,2024/8/17,3105232
EURO,2024/8/18,2790051
EURO,2024/8/19,2543361
EURO,2024/8/20,2373549
EURO,2024/8/21,2274405
EURO,2024/8/22,2230448
EURO,2024/8/23,2226769
EURO,2024/8/24,2246796
EURO,2024/8/25,2278226
EURO,2024/8/26,2311227
EURO,2024/8/27,2338624
EURO,2024/8/28,2355797
EURO,2024/8/29,2360350
EURO,2024/8/30,2351055
EURO,2024/8/31,2328012
EURO,2024/9/1,2292131
EURO,2024/9/2,2244510
EURO,2024/9/3,2186929
EURO,2024/9/4,2120982
EURO,2024/9/5,2048385
EURO,2024/9/6,1970854
EURO,2024/9/7,1889828
ILI_test.csv
WHO region,ISO_SDATE,INF_ALL
EURO,2024/8/1,0
EURO,2024/8/2,624345
EURO,2024/8/3,1172978
EURO,2024/8/4,1881159
EURO,2024/8/5,2906729
EURO,2024/8/6,4403334
EURO,2024/8/7,6537484
EURO,2024/8/8,9421065
EURO,2024/8/9,13078212
EURO,2024/8/10,17175683
EURO,2024/8/11,21134019
EURO,2024/8/12,23905744
EURO,2024/8/13,24794422
EURO,2024/8/14,23448470
EURO,2024/8/15,20533115
EURO,2024/8/16,16844322
EURO,2024/8/17,13324612
EURO,2024/8/18,10369412
EURO,2024/8/19,8161231
EURO,2024/8/20,6607051
EURO,2024/8/21,5572269
EURO,2024/8/22,4905979
EURO,2024/8/23,4471448
EURO,2024/8/24,4194026
EURO,2024/8/25,4002625
EURO,2024/8/26,3856308
EURO,2024/8/27,3735774
EURO,2024/8/28,3624106
EURO,2024/8/29,3515065
EURO,2024/8/30,3404724
EURO,2024/8/31,3291435
EURO,2024/9/1,3174570
EURO,2024/9/2,3056320
EURO,2024/9/3,2935733
EURO,2024/9/4,2815053
EURO,2024/9/5,2695352
EURO,2024/9/6,2576562
EURO,2024/9/7,2460901
Then the result is :
Is something goes wrong? The posterior samples value is around 10 while the actual data goes around 1000000.
Any insights would be greatly appreciated!
Thanks in advance!