Huge difference between MCMC posterior samples and actual data - help needed!

Hi everyone,

I have been working on Bayesian inference models using a SEIRS framework to analyze COVID-19 and influenza co-infection data, and I have run into a major problem. I generated some data with the SEIRS model and fed it into the code below, but the results were not satisfactory: after running MCMC sampling, the posterior samples differ from the actual observed data by several orders of magnitude.

Here is my code:

import csv
from datetime import datetime, timedelta

import matplotlib.pyplot as plt
import numpy as np
import pymc as pm
import arviz as az
import sunode
import sunode.wrappers.as_pytensor
from scipy.interpolate import interp1d

class DiseaseDataVisualizer:
    def __init__(self, covid_file_path, influenza_file_path):
        self.covid_file_path = covid_file_path
        self.influenza_file_path = influenza_file_path
        self.covid_data = self.read_covid_data(self.covid_file_path)
        self.influenza_data = self.read_influenza_data(self.influenza_file_path)

    def read_covid_data(self, file_path):
        disease_data = {}

        with open(file_path, 'r') as file:
            reader = csv.DictReader(file)
            for row in reader:
                region = row['WHO_region']
                date_reported_str = row['Date_reported']
                try:
                    date_reported = datetime.strptime(date_reported_str, '%Y/%m/%d')
                except ValueError:
                    date_reported = datetime.strptime(date_reported_str, '%Y-%m-%dT%H:%M:%S.%fZ')
                new_cases = int(row['New_cases']) if row['New_cases'] else 0

                if region not in disease_data:
                    disease_data[region] = {'dates': [], 'new_cases': []}

                if date_reported in disease_data[region]['dates']:
                    index = disease_data[region]['dates'].index(date_reported)
                    disease_data[region]['new_cases'][index] += new_cases
                else:
                    disease_data[region]['dates'].append(date_reported)
                    disease_data[region]['new_cases'].append(new_cases)

        return disease_data

    def read_influenza_data(self, file_path):
        disease_data = {}

        with open(file_path, 'r') as file:
            reader = csv.DictReader(file)
            for row in reader:
                region = row['WHO region']
                try:
                    date_reported_str = row['ISO_SDATE']
                except KeyError:
                    date_reported_str = row['ISO_WEEKSTARTDATE']
                try:
                    date_reported = datetime.strptime(date_reported_str, '%Y/%m/%d')
                except ValueError:
                    date_reported = datetime.strptime(date_reported_str, '%Y-%m-%dT%H:%M:%S.%fZ')
                new_cases = int(row['INF_ALL']) if row['INF_ALL'] else 0

                if region not in disease_data:
                    disease_data[region] = {'dates': [], 'new_cases': []}

                if date_reported in disease_data[region]['dates']:
                    index = disease_data[region]['dates'].index(date_reported)
                    disease_data[region]['new_cases'][index] += new_cases
                else:
                    disease_data[region]['dates'].append(date_reported)
                    disease_data[region]['new_cases'].append(new_cases)

        return disease_data
    
    def interpolate_data(self, dates, cases):
        date_range = (dates[-1] - dates[0]).days
        all_dates = [dates[0] + timedelta(days=i) for i in range(date_range + 1)]
        
        # Convert dates to ordinal (numbers) for interpolation
        ordinal_dates = [date.toordinal() for date in dates]
        all_ordinal_dates = [date.toordinal() for date in all_dates]

        # Perform linear interpolation
        interp_func = interp1d(ordinal_dates, cases, kind='linear', fill_value="extrapolate")
        interpolated_cases = interp_func(all_ordinal_dates)
        
        return all_dates, interpolated_cases

    def get_region_data(self, region):
        covid_dates = self.covid_data[region]['dates']
        covid_cases = self.covid_data[region]['new_cases']
        influenza_dates = self.influenza_data[region]['dates']
        influenza_cases = self.influenza_data[region]['new_cases']

        interpolated_covid_dates, interpolated_covid_cases = self.interpolate_data(covid_dates, covid_cases)
        interpolated_influenza_dates, interpolated_influenza_cases = self.interpolate_data(influenza_dates, influenza_cases)

        data = {
            'covid_dates': interpolated_covid_dates,
            'covid_cases': interpolated_covid_cases,
            'influenza_dates': interpolated_influenza_dates,
            'influenza_cases': interpolated_influenza_cases
        }
        return data

class SEIRS_model():
    def __init__(self, region, data) -> None:
        self.covid_dates = data['covid_dates']
        self.covid_cases = data['covid_cases']
        self.influenza_dates = data['influenza_dates']
        self.influenza_cases = data['influenza_cases']

        self.parameter = {
            'sigma_1' : 0.15,
            'sigma_2' : 0.5,
            'gamma_1' : 1/7,
            'gamma_2' : 0.2,
            'gamma_3' : 0.1,
            'theta1' : 0.001,
            'theta2' : 0.001,
            'xi' : 1/365 ,
        }
        self.samples_params = {
            'n_samples' : 100,
            'n_tune' : 20,
            'cores' : 12,
        }
        self.region_population = {
            'EURO' : 744000000,
            'AMRO' : 1018000000,
            'AFRO' : 1216000000,
            'SEARO' : 1984000000,
            'WPRO' : 1650000000,
            'EMRO' : 654000000,
        }
        self.region = region
        self.times = np.arange(0,max(len(self.covid_dates), len(self.influenza_dates)),1)

        self.init_y = [
            self.region_population[self.region] - self.covid_cases[0] / self.parameter['sigma_1'] - self.influenza_cases[0] / self.parameter['sigma_2'] - 0.38*self.region_population[self.region],  # S0
            self.covid_cases[0] / self.parameter['sigma_1'],      # E1
            self.influenza_cases[0] / self.parameter['sigma_2'],  # E2
            0,  # E120
            0,  # E210
            0,  # I0
            0,  # I1
            0,  # I2
            0,  # S12
            0,  # S21
            0,  # E12
            0,  # E21
            0,  # I12
            0,  # I21
            0.38*self.region_population[self.region]  # R
        ]
    def SEIRS_sunode(self, t, y, p):
        return {
            'S' : -p.beta1 * (y.S / p.n) * (y.I1 + y.I0) - p.beta2 * (y.S / p.n) * (y.I2 + y.I0) + p.xi * y.R,
            'E1' : p.beta1 * (y.S / p.n) * (y.I1 + y.I0) - p.sigma1 * y.E1 - p.theta1 * y.E1,
            'E2' : p.beta2 * (y.S / p.n) * (y.I2 + y.I0) - p.sigma2 * y.E2 - p.theta2 * y.E2,
            'E120' : p.theta2 * y.E1 - p.sigma1 * y.E120,
            'E210' : p.theta1 * y.E2 - p.sigma2 * y.E210,
            'I0' : p.sigma1 * y.E120 + p.sigma2 * y.E210 - p.gamma3 * y.I0,
            'I1' : p.sigma1 * y.E1 - p.gamma1 * y.I1,
            'I2' : p.sigma2 * y.E2 - p.gamma2 * y.I2,
            'S12' : p.gamma1 * y.I1 - p.gamma3 * (y.I12 + y.I2 + y.I0) * (y.S12 / p.n),
            'S21' : p.gamma2 * y.I2 - p.gamma3 * (y.I21 + y.I1 + y.I0) * (y.S21 / p.n),
            'E12' : p.delta1 * (y.I12 + y.I2 + y.I0) * (y.S12 / p.n) - p.sigma1 * y.E12,
            'E21' : p.delta2 * (y.I21 + y.I1 + y.I0) * (y.S21 / p.n) - p.sigma2 * y.E21,
            'I12' : p.sigma1 * y.E12 - p.gamma2 * y.I12,
            'I21' : p.sigma2 * y.E21 - p.gamma1 * y.I21,
            'R' : p.gamma2 * y.I12 + p.gamma1 * y.I21 - p.xi * y.R + p.gamma3 * y.I0
        }
    
    def run_SEIRS_model(self):
        with pm.Model() as self.model:
            # Priors for unknown model parameters
            beta1 = pm.LogNormal('beta1', np.log(0.5), 1)
            beta2 = pm.LogNormal('beta2', np.log(0.2), 1)
            delta1 = pm.LogNormal('delta1', np.log(0.5), 1)
            delta2 = pm.LogNormal('delta2', np.log(0.2), 1)
            sigma = pm.HalfNormal('sigma', 1, shape=1)
            self.report_rate = 1
            
            res, _, problem, solver, _, _ = sunode.wrappers.as_pytensor.solve_ivp(
            y0={
            # The initial conditions of the ode. Each variable
            # needs to specify a pytensor or numpy variable and a shape.
            # This dict can be nested.
                'S': (self.init_y[0], ()),
                'E1': (self.init_y[1], ()),
                'E2': (self.init_y[2], ()),
                'E120': (self.init_y[3], ()),
                'E210': (self.init_y[4], ()),
                'I0': (self.init_y[5], ()),
                'I1': (self.init_y[6], ()),
                'I2': (self.init_y[7], ()),
                'S12': (self.init_y[8], ()),
                'S21': (self.init_y[9], ()),
                'E12': (self.init_y[10], ()),
                'E21': (self.init_y[11], ()),
                'I12': (self.init_y[12], ()),
                'I21': (self.init_y[13], ()),
                'R': (self.init_y[14], ())},
            params={
            # Each parameter of the ode. sunode will only compute derivatives
            # with respect to pytensor variables. The shape needs to be specified
            # as well. It is inferred automatically for numpy variables.
            # This dict can be nested.
                'beta1': (beta1, ()),
                'beta2': (beta2, ()),
                'delta1': (delta1, ()),
                'delta2': (delta2, ()),
                'theta1': (self.parameter['theta1'], ()),
                'theta2': (self.parameter['theta2'], ()),
                'sigma1': (self.parameter['sigma_1'], ()),
                'sigma2': (self.parameter['sigma_2'], ()),
                'gamma1': (self.parameter['gamma_1'], ()),
                'gamma2': (self.parameter['gamma_2'], ()),
                'gamma3': (self.parameter['gamma_3'], ()),
                'xi': (self.parameter['xi'], ()),
                'n': (self.region_population[self.region], ()),
                '_dummy': (np.array(1.), ()),
            },
            # A function that computes the right-hand side of the ode using
            # sympy variables.
            rhs=self.SEIRS_sunode,
            # The time points where we want to access the solution
            tvals=self.times,
            t0=self.times[0]
            )

            I1_all = self.parameter['sigma_1']*(res['E1'] + res['E21'] + res['E210'])
            I2_all = self.parameter['sigma_2']*(res['E2'] + res['E12'] + res['E120'])

            self.covid_relative_cases = self.report_rate * self.covid_cases
            self.influenza_relative_cases = self.report_rate * self.influenza_cases

            covid_obs = pm.StudentT('covid_obs', nu=10, mu=I1_all, sigma=sigma, observed=self.covid_relative_cases)
            influenza_obs = pm.StudentT('influenza_obs', nu=10, mu=I2_all, sigma=sigma, observed=self.influenza_relative_cases)

            # Run the MCMC
            trace = pm.sample(self.samples_params['n_samples'], tune=self.samples_params['n_tune'], cores=self.samples_params['cores'])
            trace.to_netcdf('./StudentT_trace_test.nc')
        return trace
            
            
        
    def analysis(self, trace):

        az.plot_trace(trace)
        
        print(az.summary(trace))
        az.plot_forest(trace, r_hat=True)
        az.plot_posterior(trace)

        with self.model:
            ppc_samples = pm.sample_posterior_predictive(trace,extend_inferencedata=True)

        covid_obs_check = ppc_samples.posterior_predictive['covid_obs']
        influenza_obs_check = ppc_samples.posterior_predictive['influenza_obs']

        covid_obs_mean = covid_obs_check.mean(axis=0)
        covid_obs_CriL = np.percentile(covid_obs_check, q=2.5, axis=0)
        covid_obs_CriU = np.percentile(covid_obs_check, q=97.5, axis=0)

        influenza_obs_mean = influenza_obs_check.mean(axis=0)
        influenza_obs_CriL = np.percentile(influenza_obs_check, q=2.5, axis=0)
        influenza_obs_CriU = np.percentile(influenza_obs_check, q=97.5, axis=0)


        plt.figure(figsize=(15, 2 * (5)))
        plt.subplot(2, 1, 1)
        plt.plot(self.covid_relative_cases, "o", color="r", lw=1, ms=10.5, label="Observed")
        for i in range(covid_obs_mean.shape[0]):
            plt.plot(self.times, covid_obs_mean[i, :], color="b", lw=0.3)


        plt.legend(fontsize=15)
        plt.xlabel("Days", fontsize=15)
        plt.ylabel("Covid", fontsize=15)

        plt.subplot(2, 1, 2)
        plt.plot(self.influenza_relative_cases, "o", color="b", lw=1, ms=10.5, label="Observed")
        for i in range(influenza_obs_mean.shape[0]):
            plt.plot(self.times, influenza_obs_mean[i, :], color="r", lw=0.3)

        plt.legend(fontsize=15)
        plt.xlabel("Days", fontsize=15)
        plt.ylabel("Influenza", fontsize=15)

        plt.show(block=True)
        print('done')


if __name__ == '__main__':
    region = 'EURO'
    data_visualizer = DiseaseDataVisualizer('COVID_Test.csv', 'ILI_test.csv')
    data = data_visualizer.get_region_data(region)

    model = SEIRS_model(region, data)
    trace = model.run_SEIRS_model()
    model.analysis(trace)

My test data:
COVID_Test.csv

WHO_region,Date_reported,New_cases
EURO,2024/8/1,0
EURO,2024/8/2,252316
EURO,2024/8/3,456938
EURO,2024/8/4,661231
EURO,2024/8/5,896038
EURO,2024/8/6,1183257
EURO,2024/8/7,1538148
EURO,2024/8/8,1967152
EURO,2024/8/9,2460277
EURO,2024/8/10,2983915
EURO,2024/8/11,3473403
EURO,2024/8/12,3846711
EURO,2024/8/13,4029418
EURO,2024/8/14,3996789
EURO,2024/8/15,3779914
EURO,2024/8/16,3455787
EURO,2024/8/17,3105232
EURO,2024/8/18,2790051
EURO,2024/8/19,2543361
EURO,2024/8/20,2373549
EURO,2024/8/21,2274405
EURO,2024/8/22,2230448
EURO,2024/8/23,2226769
EURO,2024/8/24,2246796
EURO,2024/8/25,2278226
EURO,2024/8/26,2311227
EURO,2024/8/27,2338624
EURO,2024/8/28,2355797
EURO,2024/8/29,2360350
EURO,2024/8/30,2351055
EURO,2024/8/31,2328012
EURO,2024/9/1,2292131
EURO,2024/9/2,2244510
EURO,2024/9/3,2186929
EURO,2024/9/4,2120982
EURO,2024/9/5,2048385
EURO,2024/9/6,1970854
EURO,2024/9/7,1889828

ILI_test.csv

WHO region,ISO_SDATE,INF_ALL
EURO,2024/8/1,0
EURO,2024/8/2,624345
EURO,2024/8/3,1172978
EURO,2024/8/4,1881159
EURO,2024/8/5,2906729
EURO,2024/8/6,4403334
EURO,2024/8/7,6537484
EURO,2024/8/8,9421065
EURO,2024/8/9,13078212
EURO,2024/8/10,17175683
EURO,2024/8/11,21134019
EURO,2024/8/12,23905744
EURO,2024/8/13,24794422
EURO,2024/8/14,23448470
EURO,2024/8/15,20533115
EURO,2024/8/16,16844322
EURO,2024/8/17,13324612
EURO,2024/8/18,10369412
EURO,2024/8/19,8161231
EURO,2024/8/20,6607051
EURO,2024/8/21,5572269
EURO,2024/8/22,4905979
EURO,2024/8/23,4471448
EURO,2024/8/24,4194026
EURO,2024/8/25,4002625
EURO,2024/8/26,3856308
EURO,2024/8/27,3735774
EURO,2024/8/28,3624106
EURO,2024/8/29,3515065
EURO,2024/8/30,3404724
EURO,2024/8/31,3291435
EURO,2024/9/1,3174570
EURO,2024/9/2,3056320
EURO,2024/9/3,2935733
EURO,2024/9/4,2815053
EURO,2024/9/5,2695352
EURO,2024/9/6,2576562
EURO,2024/9/7,2460901

Then the result is:


Did something go wrong? The posterior sample values are around 10, while the actual data is on the order of 1,000,000.
Any insights would be greatly appreciated!

Thanks in advance!

The first recommendation would be to try to simplify the problem, which by itself might already help you figure out what is going on, as well as making it easier for other people to reproduce the issue and give advice.

Other than that, the things I'd try would be:

Prior predictive checks: make sure your priors aren't forcing the data to be on this lower scale. Potentially useful reference: Prior and Posterior Predictive Checks - PyMC 5.16.2 documentation
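
For example, a minimal prior predictive check might look like this (just a sketch, assuming the self.model context from your code above):

with self.model:
    prior_pred = pm.sample_prior_predictive(500)
# If the prior draws for covid_obs / influenza_obs sit orders of magnitude
# below the observed counts, the priors (or the data scaling) are the problem
az.plot_ppc(prior_pred, group="prior")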

Increase the number of tuning steps, by a lot. The default is 1000, and it is more important to keep the tuning iterations close to that number than it is to keep the number of posterior samples high. If you are indeed doing only 20 tuning steps, what you are seeing are mostly random results which probably have nothing to do with the model.

Check that there are no clear signs of the model not having converged. Things like r_hat > 1.01 or ess_bulk < 400 mean you can't really trust the outputs of the model.
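
A quick way to scan for those warning signs (a sketch, assuming trace is the InferenceData returned by pm.sample):

summary = az.summary(trace)
# Any row flagged here indicates that parameter's chains have not converged
bad = summary[(summary["r_hat"] > 1.01) | (summary["ess_bulk"] < 400)]
print(bad)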


Thanks for your advice! The NUTS sampler runs too slowly; it always takes days per debugging run, so I set the sample counts small to give it a try. Today I changed it to Metropolis and increased the sample counts.

        # changed to [1] because both [0] values are 0
        self.init_y = [
            self.region_population[self.region] - self.covid_cases[1] / self.parameter['sigma_1'] - self.influenza_cases[1] / self.parameter['sigma_2'] - 0.38*self.region_population[self.region],  # S0
            self.covid_cases[1] / self.parameter['sigma_1'],      # E1
            self.influenza_cases[1] / self.parameter['sigma_2'],  # E2
            0,  # E120
            0,  # E210
            0,  # I0
            0,  # I1
            0,  # I2
            0,  # S12
            0,  # S21
            0,  # E12
            0,  # E21
            0,  # I12
            0,  # I21
            0.38*self.region_population[self.region]  # R
        ]
            # Run the MCMC with Metropolis instead of the NUTS default
            step = pm.Metropolis()
            trace = pm.sample(self.samples_params['n_samples'], tune=self.samples_params['n_tune'],
                              cores=self.samples_params['cores'], step=step)

Finally, the results look better:

But the results are still not very satisfactory, so I plan to keep increasing n_tune and trying.
In addition, I found that I seem to have misunderstood pm.sample: the number of tuning samples appears to matter more than the formal sampling.
As long as the fit does not converge, I can try increasing the number of tuning steps, is that right?

Slow sampling and lack of convergence are likely signs of a bad model. There are always simpler models that you can fit, and you can reduce the dataset. Increasing tuning is unlikely to fix your problem in this case (at least that's my suspicion).

If you build your model incrementally / expand slowly to the full dataset, it will be easier to spot where things start to break and to get advice here on the forum.


First of all, thank you very much for your prompt reply!

The thing is, my model is the SEIRS compartment model of COVID and influenza co-infection, which is at the heart of my research, and I can't change it. Also, after I increased the number of tuning steps, the results were indeed closer to what I expected. I thought maybe I should keep trying first.

        self.samples_params = {
            'n_samples' : 2000,
            'n_tune' : 50000,
            'cores' : 12,
        }

The result:


Thanks again for the suggestion! It was so important!

Let me be frank: if you need 50k tuning steps and you dropped NUTS, your model is crap/failing. Either you have a bug, the data is problematic, or the model is simply non-identifiable.

Metropolis will mask the problems, not overcome them.

My suggestion to start simpler is to find where the problem lies, not for you to settle on a simpler alternative.


Thank you for your sincere advice.
Referring to this document, I rebuilt a simpler single-disease SIR model to test my code. Judging from the results, my model seems to have converged and NUTS runs noticeably faster, but the posterior predictive samples are still not satisfying.
Here is the code:

import csv
from datetime import datetime, timedelta

import matplotlib.pyplot as plt
import numpy as np
import pymc as pm
import arviz as az
import sunode
import sunode.wrappers.as_pytensor
from scipy.interpolate import interp1d

# -------- Usage --------#
# covid_obj = COVID_data('US', Population=328.2e6)
# covid_obj.get_dates(data_begin='7/11/20', data_end='7/20/20')
# sir_model = SIR_model(covid_obj)
# likelihood = {'distribution': 'lognormal', 'sigma': 2}
# prior = {'lam': 0.4, 'mu': 1/8, 'lambda_std': 0.5, 'mu_std': 0.5}
# sir_model.run_SIR_model(n_samples=20, n_tune=10, likelihood=likelihood)
np.random.seed(0)
 
# The DiseaseDataVisualizer class is identical to the one in my first post
# above, so it is omitted here.
 
 
 
class SIR_model_sunode():
 
    def __init__(self, covid_data):
 
        # ------------------------- Covid_data object -----------------------#
        self.covid_data = covid_data
        # ------------------------- Setup SIR model, but has to be called explicitly to run ------------------------#
        self.setup_SIR_model()
 
    def SIR_sunode(self, t, y, p):
        return {
            'S': -p.lam * y.S * y.I,
            'I': p.lam * y.S * y.I - p.mu * y.I,
        }
 
    def setup_SIR_model(self):
        self.N = 10000000

        self.time_range = np.arange(0,len(self.covid_data),1)
        self.I0 = self.covid_data[0]
        self.S0 = self.N - self.I0
        self.S_init = self.S0 / self.N
        self.I_init = self.I0 / self.N
        self.cases_obs_scaled = self.covid_data / self.N
 
 
    def run_SIR_model(self, n_samples, n_tune, likelihood, prior):
        # ------------------------- Metadata --------------------------------#
        self.filename = 'Test'
        self.likelihood = likelihood
        self.n_samples = n_samples
        self.n_tune = n_tune
        self.prior = prior
 
        with pm.Model() as model4:
            sigma = pm.HalfCauchy('sigma', self.likelihood['sigma'], shape=1)
            lam_mu = np.log(self.prior['lam']) + self.prior['lambda_std']**2
            mu_mu = np.log(self.prior['mu']) + self.prior['mu_std']**2
            lam = pm.Lognormal('lambda', lam_mu , self.prior['lambda_std']) # 1.5, 1.5
            mu = pm.Lognormal('mu', mu_mu, self.prior['mu_std'])           # 1.5, 1.5
 
            res, _, problem, solver, _, _ = sunode.wrappers.as_pytensor.solve_ivp(
            y0={
            # The initial conditions of the ode. Each variable
            # needs to specify a pytensor or numpy variable and a shape.
            # This dict can be nested.
                'S': (self.S_init, ()),
                'I': (self.I_init, ()),},
            params={
            # Each parameter of the ode. sunode will only compute derivatives
            # with respect to pytensor variables. The shape needs to be specified
            # as well. It is inferred automatically for numpy variables.
            # This dict can be nested.
                'lam': (lam, ()),
                'mu': (mu, ()),
                '_dummy': (np.array(1.), ())},
            # A function that computes the right-hand side of the ode using
            # sympy variables.
            rhs=self.SIR_sunode,
            # The time points where we want to access the solution
            tvals=self.time_range,
            t0=self.time_range[0]
            )
            if(likelihood['distribution'] == 'lognormal'):
                I = pm.Lognormal('I', mu=res['I'], sigma=sigma, observed=self.cases_obs_scaled)
            elif(likelihood['distribution'] == 'normal'):
                I = pm.Normal('I', mu=res['I'], sigma=sigma, observed=self.cases_obs_scaled)
            elif(likelihood['distribution'] == 'students-t'):
                I = pm.StudentT( "I",  nu=likelihood['nu'],       # likelihood distribution of the data
                        mu=res['I'],     # likelihood distribution mean, these are the predictions from SIR
                        sigma=sigma,
                        observed=self.cases_obs_scaled
                        )
 
            trace = pm.sample(self.n_samples, tune=self.n_tune, cores=1)
 
        az.plot_posterior(trace)
        az.plot_trace(trace)
        print(az.summary(trace))
        with model4:
            ppc_samples = pm.sample_posterior_predictive(trace)
        _, ax = plt.subplots()
        ax.set_xlim([-1,1])
        az.plot_ppc(ppc_samples, ax=ax)
        covid_obs_check = ppc_samples.posterior_predictive['I']

        covid_obs_mean = covid_obs_check.mean(axis=0)
        covid_obs_CriL = np.percentile(covid_obs_check, q=2.5, axis=0)
        covid_obs_CriU = np.percentile(covid_obs_check, q=97.5, axis=0)

        plt.figure(figsize=(15, 5))

        plt.plot(self.cases_obs_scaled, "o", color="r", lw=1, ms=1.5, label="Observed")
        for i in range(covid_obs_mean.shape[0]):
            plt.plot(self.time_range, covid_obs_mean[i, :], color="b", lw=0.3)


        plt.legend(fontsize=15, loc='upper left')
        plt.xlabel("Days", fontsize=15)
        plt.ylabel("Covid", fontsize=15)
        plt.savefig('StudentT_SIR.png')
        plt.ylim(0,1)        
        plt.savefig('StudentT_SIR0-1.png')
        plt.show(block=True)
        print('done')


if __name__ == '__main__':
    region = 'EURO'
    res = DiseaseDataVisualizer('COVID_Test.csv', 'ILI_test.csv')
    covid_obj = res.get_region_data(region)['covid_cases']
    sir_model = SIR_model_sunode(covid_obj)
    likelihood = {'distribution': 'lognormal', 'sigma': 2}
    prior= {'lam': 0.4, 'mu': 1/8, 'lambda_std' : 0.5 ,'mu_std': 0.5}
    sir_model.run_SIR_model(n_samples=200, n_tune=2000, likelihood=likelihood, prior=prior)

The result:

           mean     sd  hdi_3%  hdi_97%  mcse_mean  mcse_sd  ess_bulk  ess_tail  r_hat
lambda    0.411  0.202   0.145    0.734      0.009    0.007     456.0     336.0   1.00
mu        0.235  0.113   0.051    0.448      0.006    0.004     399.0     287.0   1.00
sigma[0]  2.615  0.316   2.046    3.172      0.015    0.011     445.0     270.0   1.01

However, the sample_posterior_predictive output looks like this:


And when I enlarge this image:

The observed points are far below the curves; the trend of the actual data is not reflected at all. Maybe there is a problem with my manual posterior predictive plotting?

Next I will continue to test a dual disease SIR model to see what went wrong. Thanks in advance!

Changed to a dual-disease SIR model:

    def SIR_sunode(self, t, y, p):
        return {
            'S0': -p.lam1 * y.S0 * (y.I1 + y.I21)  - p.lam2 * y.S0 * (y.I2 + y.I12),
            'I1': -p.lam1 * y.S0 * (y.I1 + y.I21),
            'I2': -p.lam2 * y.S0 * (y.I2 + y.I12),
            'S12': p.mu1*y.I1 - p.lam21*y.S12*(y.I2 + y.I12),
            'S21': p.mu2*y.I2 - p.lam12*y.S21*(y.I1 + y.I21),
            'I12': p.lam21*y.S12*(y.I2 + y.I12) - p.mu2*y.I12,
            'I21': p.lam12*y.S21*(y.I1 + y.I21) - p.mu1*y.I21,
        }

Result:

           mean     sd  hdi_3%  hdi_97%  mcse_mean  mcse_sd  ess_bulk  ess_tail  r_hat
sigma[0]  8.634  0.721   7.384   10.036      0.013    0.009    3274.0    1864.0   1.00
lambda1   0.567  0.294   0.133    1.050      0.005    0.004    3639.0    1999.0   1.00
lambda2   0.580  0.307   0.138    1.118      0.006    0.005    3132.0    1943.0   1.00
lambda12  0.584  0.313   0.142    1.146      0.006    0.005    4009.0    1975.0   1.01
lambda21  0.591  0.329   0.154    1.160      0.007    0.006    3802.0    1809.0   1.00
mu1       0.181  0.097   0.045    0.359      0.002    0.001    3748.0    1912.0   1.00
mu2       0.177  0.092   0.048    0.339      0.002    0.001    3965.0    1962.0   1.00

Seems great. But I want a clear figure in which the posterior predictive fits the plot of the actual data, and when I enlarge the figure, it doesn't match the data trend at all.

If the posterior looks good, you could try to compute the posterior predictive by hand to see if you get the same "strange" misfit as via PyMC.
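
For instance, a by-hand version for the SIR model above might look like this (a sketch using scipy's odeint in place of sunode; trace, S_init, I_init, and time_range are assumed to come from the earlier code):

import numpy as np
from scipy.integrate import odeint

def sir_rhs(y, t, lam, mu):
    S, I = y
    return [-lam * S * I, lam * S * I - mu * I]

post = trace.posterior.stack(sample=("chain", "draw"))
rng = np.random.default_rng(0)
ppc = []
for lam, mu, sig in zip(post["lambda"].values, post["mu"].values,
                        post["sigma"].values[0]):
    sol = odeint(sir_rhs, [S_init, I_init], time_range, args=(lam, mu))
    # Mirror the model's likelihood exactly: PyMC's Lognormal mu is the
    # log-scale location, and the model passes the ODE solution in directly
    ppc.append(rng.lognormal(mean=sol[:, 1], sigma=sig))
ppc = np.array(ppc)

If this hand-rolled version shows the same offset, the misfit comes from the model/likelihood itself rather than from sample_posterior_predictive.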

You need to be super gentle with the priors on a LogNormal likelihood. The variance is (\exp(\sigma^2) - 1)\exp(2\mu + \sigma^2), so even at \mu=0, with your estimate of \sigma \approx 8 you're going to have a silly huge variance (which is indeed what you see with those spikes up to 1e15).
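
As a quick sanity check on the scale (a sketch, plugging in \mu = 0 and \sigma = 8):

import numpy as np
mu, sigma = 0.0, 8.0
var = (np.exp(sigma**2) - 1) * np.exp(2 * mu + sigma**2)
print(var)  # roughly 3.9e55, absurd for data scaled to the unit interval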

I've recently started just modeling the log of the data using a Normal, because I've found the LogNormal too unstable to work with. I was initially drawn to it because it's more "principled", but I've been bitten by outputs that look exactly like yours too many times.
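
In the SIR code above that swap might look like this (a sketch; the epsilon guarding the zero first observation is my addition, not part of the original model):

log_cases = np.log(self.cases_obs_scaled + 1e-12)  # guard: the first observation is 0
I = pm.Normal('log_I', mu=pm.math.log(res['I']), sigma=sigma, observed=log_cases)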


One thing I could pick up from your posteriors is that the range of your λ values is very large for this type of ODE (between 0.1 and 1). Is that your force-of-infection term? You could try pymc.TruncatedNormal or other truncated distributions to get a sense of where your priors on them should be (and also, artificially, to reduce your variance). After you have a good set of priors, you might consider pymc.Potential to put weights around your observed values, and remove all the Truncated stuff you used before; see the sketch after this post.
From my own experience, the usual Bayesian diagnostics (divergences, ess, r_hat) are not reliable indicators of whether your ODE model is sufficient when your posteriors are not anywhere near the observed values.
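
A sketch of those two suggestions applied to the SIR model above (all bounds, means, and the weight scale are illustrative assumptions, not recommendations):

# Stage 1: truncated priors to locate a sensible range for the rates
lam = pm.TruncatedNormal('lambda', mu=0.4, sigma=0.2, lower=0.05, upper=1.0)
mu = pm.TruncatedNormal('mu', mu=0.125, sigma=0.05, lower=0.02, upper=0.5)
# Stage 2 (after settling on priors, dropping the truncation): a soft weight
# pulling the ODE trajectory toward the observations
pm.Potential('fit_weight',
             -pm.math.sum(((res['I'] - self.cases_obs_scaled) / 0.01) ** 2))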


Thank you for your reply. I also found that NUTS sampling was more than ten times faster with the lognormal distribution; when I switched back to other distributions such as the normal, the speed dropped significantly.

The function of the lognormal distribution seems to be to suppress the actual data into a straight line close to the X-axis, thereby ā€˜forcingā€™ the posterior distribution to conform to the actual data. But when I zoomed in on it, it didnā€™t go the same way at all. To be honest, Iā€™m a bit doubtful of its effectiveness :melting_face: