Hi all. I’m trying to adapt the approach to estimate incubation periods from this paper: https://www.ncbi.nlm.nih.gov/pmc/articles/PMC4757028/#pone.0148506.ref002
R code and data here: Dryad | Data -- Association between the severity of influenza A(H7N9) virus infections and length of the incubation period .
However, I cannot get similar results to those reported in the paper, and I cannot figure out why. Here’s my code:
# -*- coding: utf-8 -*-
import pandas as pd
import numpy as np
import pymc as pm
import arviz as az
import pytensor.tensor as pt
from scipy.special import gammainc as gammai
np.random.seed(27)
def weibull_cdf(x, kappa, theta):
return 1 - pt.exp(-(x/theta)**kappa)
def censored(name, kappa, theta, lower, upper):
L = weibull_cdf(lower, kappa, theta)
U = weibull_cdf(upper, kappa, theta)
return U - L
data = pd.read_csv("./original_data/data_h7n9_severity.csv")
lower = data.IncP_min.values #incubation periods lower boundary
upper = data.IncP_max.values #incubation periods upper boundary
fatal = data.death_status.values
with pm.Model() as mod:
# kappa = pm.HalfNormal('kappa', 5)
# theta = pm.HalfNormal('theta', 5)
kappa = pm.Uniform('kappa', 0, 100, shape=2) #as used in the paper
theta = pm.Uniform('theta', 0, 100, shape=2) #as used in the paper
mu = pm.Deterministic('mu', theta*pt.gamma(1 + 1 / kappa)) #incubation period mean
y = pm.Potential('y', censored('censored', kappa[fatal], theta[fatal], lower, upper))
with mod:
idata = pm.sample(2000, tune=2000, nuts_sampler='numpyro', random_seed=27)
az.summary(idata)
Out[27]:
mean sd hdi_3% hdi_97% ... mcse_sd ess_bulk ess_tail r_hat
kappa[0] 51.804 27.256 7.450 96.496 ... 0.643 831.0 917.0 1.01
kappa[1] 55.424 26.378 14.275 99.970 ... 0.574 969.0 829.0 1.00
theta[0] 3.569 1.385 1.252 6.081 ... 0.043 530.0 627.0 1.01
theta[1] 3.825 1.130 1.371 4.991 ... 0.048 423.0 681.0 1.01
f[0] 3.503 1.379 1.180 5.957 ... 0.042 546.0 738.0 1.01
f[1] 3.771 1.131 1.309 4.986 ... 0.046 427.0 721.0 1.01
mu[0] 3.508 1.363 1.260 5.991 ... 0.042 531.0 644.0 1.01
mu[1] 3.769 1.115 1.365 4.936 ... 0.047 420.0 698.0 1.01
[8 rows x 9 columns]
Any help will be really appreciated.