I’m experimenting with a certain stochastic model of disease. I have an array where each column represents a person and each row represents a time. The entries are integers corresponding to “states” the person might be in – for this disease they are “Susceptible” (0), “Latent” (1), “Infectious” (2) and “Recovered” (3). The leftmost column is an exception: its entries are floating point numbers representing timestamps.
At the moment I’m trying to fit three parameters to this data: gamma, zeta, and xi. Loosely speaking, gamma quantifies the transmission risk per infected person in the community, zeta is the rate at which people leave the incubation (“Latent”) period (so 1/zeta is its average length), and xi is the rate at which people leave the period of infectiousness (so 1/xi is its average length).
So that you can follow along at home, I have uploaded fake data from a simulation where I know the underlying true values of the three parameters for this particular model. It should be attached to this post:
test_data.txt (2.5 MB)
Here’s the code to prepare it:
import numpy as np

test_trajs = np.loadtxt(fname="test_data.txt", delimiter=",")
times = test_trajs[:,0]
## we need the length of time spent in each row
dts = np.diff(times)
## we need to compare "states now" to "states next"
## (plain int rather than np.int, which newer NumPy releases have removed)
state_history = test_trajs[:,1:].astype(int)
states_next = np.roll(state_history, shift=-1, axis=0)
## remove last entry to prevent off-by-one error
state_history = state_history[:-1,:]
states_next = states_next[:-1,:]
## 2D boolean arrays that appear in the likelihood
changes = state_history!=states_next
is_s = (state_history==0)
is_l = (state_history==1)
is_i = (state_history==2)
is_r = (state_history==3)
## 1D integer array for number of ppl in infectious state at each interval
n_infected = is_i.sum(axis=1)
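To make the preparation step concrete, here is a minimal sketch of the same operations on a made-up array with four timestamps and three people. The toy values and the `toy_` names are illustrative assumptions, not taken from test_data.txt. Slicing with `[1:]` is used in place of `np.roll` followed by dropping the last row; the two should give the same "states next" array.

```python
import numpy as np

# Hypothetical toy trajectory: column 0 is the timestamp, columns 1-3 are people.
toy = np.array([
    [0.0, 2, 0, 0],
    [0.5, 2, 1, 0],
    [1.2, 3, 1, 0],
    [2.0, 3, 2, 0],
])

toy_dts = np.diff(toy[:, 0])           # length of each interval between rows
toy_states = toy[:, 1:].astype(int)    # integer state codes only
toy_now = toy_states[:-1, :]           # states at the start of each interval
toy_next = toy_states[1:, :]           # states at the start of the next interval
toy_changes = toy_now != toy_next      # True where a person leaves their state
toy_n_infected = (toy_now == 2).sum(axis=1)  # infectious count per interval

print(toy_dts)
print(toy_changes.astype(int))
print(toy_n_infected)
```

Each of the three derived arrays has one row fewer than the raw data, because the last timestamp only closes the final interval.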
Now, I have defined a function that takes proposed values of the parameters and returns the negative log-likelihood given the data. This is just the way the log-likelihood is defined for this kind of model. Unless I am mistaken, this function is differentiable with respect to gamma, zeta, and xi as long as all three are positive. It’s NumPy-vectorized so that it runs fast and avoids if-statements.
def negative_log_likelihood(x):
    gamma, zeta, xi = x
    qxus = ((is_s * n_infected[:,None] * gamma) +
            (is_l * zeta) + (is_i * xi))
    logls = np.log(qxus ** changes) - (qxus * dts[:, None])
    logl = logls.sum()
    return -logl
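Written out as a formula, the function above appears to compute the following. The symbols are my own shorthand for the arrays in the code and are assumptions of this sketch: $s_{t,j}$ is `state_history`, $c_{t,j} \in \{0,1\}$ is `changes`, $I_t$ is `n_infected`, and $\Delta t_t$ is `dts`.

```latex
q_{t,j}(\gamma,\zeta,\xi)
  = \gamma \, I_t \, \mathbf{1}[s_{t,j}=0]
  + \zeta \, \mathbf{1}[s_{t,j}=1]
  + \xi \, \mathbf{1}[s_{t,j}=2],
\qquad
\log L(\gamma,\zeta,\xi)
  = \sum_{t} \sum_{j} \Big( \log\!\big(q_{t,j}^{\,c_{t,j}}\big) - q_{t,j}\,\Delta t_t \Big).
```

The code takes the log of $q^{c}$ rather than multiplying $c$ by $\log q$, so entries with $q_{t,j}=0$ and $c_{t,j}=0$ (for example, recovered people) contribute $\log(0^0)=\log 1=0$ in NumPy instead of $0 \cdot (-\infty)$.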
I can do a simple convex optimization on this to get maximum likelihood estimates of the parameters:
from scipy.optimize import minimize
bounds = [(0,np.inf), (0, np.inf), (0, np.inf)]
mle = minimize(negative_log_likelihood, x0=(1,1,1), bounds=bounds)
print(mle)
This outputs:
fun: 418.72800976337373
hess_inv: <3x3 LbfgsInvHessProduct with dtype=float64>
jac: array([-0.00002842, -0.00001137, -0.00001705])
message: b'CONVERGENCE: REL_REDUCTION_OF_F_<=_FACTR*EPSMCH'
nfev: 44
nit: 10
status: 0
success: True
x: array([0.41231514, 0.85600939, 0.60539016])
This is close to the true parameter values I used to generate the data (0.4, 0.8, and 0.6). I get similarly close answers when I generate another simulation with different true parameter values (details of the simulation function are omitted here for brevity). So, I believe the log-likelihood function implementation is correct.
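As a further cross-check, a likelihood of this form seems to have closed-form maximizers: setting each partial derivative of the log-likelihood to zero gives "number of transitions out of a state divided by the exposure that multiplies the parameter". The sketch below assumes that derivation holds and runs it on a made-up toy array. The function name `closed_form_mle` and the toy values are hypothetical, and it does not guard against a zero exposure.

```python
import numpy as np

def closed_form_mle(state_now, changes, dts):
    """Hypothetical helper: transition counts divided by exposure, per parameter."""
    is_s = state_now == 0
    is_l = state_now == 1
    is_i = state_now == 2
    n_infected = is_i.sum(axis=1)
    dt = dts[:, None]

    # Exposure terms: whatever multiplies each parameter in the sum of q * dt.
    exposure_gamma = (is_s * n_infected[:, None] * dt).sum()
    exposure_zeta = (is_l * dt).sum()
    exposure_xi = (is_i * dt).sum()

    # Transition counts: how often each parameter appears inside the log terms.
    gamma_hat = (is_s & changes).sum() / exposure_gamma
    zeta_hat = (is_l & changes).sum() / exposure_zeta
    xi_hat = (is_i & changes).sum() / exposure_xi
    return np.array([gamma_hat, zeta_hat, xi_hat])

# Made-up toy data: 4 timestamps, 3 people (not from test_data.txt).
toy_times = np.array([0.0, 0.5, 1.2, 2.0])
toy_states = np.array([[2, 0, 0],
                       [2, 1, 0],
                       [3, 1, 0],
                       [3, 2, 0]])

estimates = closed_form_mle(toy_states[:-1],
                            toy_states[:-1] != toy_states[1:],
                            np.diff(toy_times))
print(estimates)
```

If the derivation is right, calling the same helper on `state_history`, `changes`, and `dts` from the real data should land very close to the optimizer's `x`.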
Now, I want to use PyMC3 to get a Bayesian posterior for the three parameters. I thought I could simply rewrite the log-likelihood function using the pm.math functions and then add it to the model via pm.Potential:
import pymc3 as pm

with pm.Model() as the_model:
    pm_gamma = pm.HalfNormal("pm_gamma", sd=1)
    pm_zeta = pm.HalfNormal("pm_zeta", sd=1)
    pm_xi = pm.HalfNormal("pm_xi", sd=1)

    qxus = ((is_s * n_infected[:,None] * pm_gamma) +
            (is_l * pm_zeta) + (is_i * pm_xi))

    logls = pm.math.log(qxus ** changes) - (qxus * dts[:,None])
    logl = pm.math.sum(logls)
    potential = pm.Potential("potential", logl)

    trace = pm.sample(draws=1000, chains=4, tune=2000)
    pm.traceplot(trace)
    print(pm.summary(trace))
This does start sampling for a while, but it eventually breaks with the following long error:
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [pm_xi, pm_zeta, pm_gamma]
Sampling 4 chains: 3%|▎ | 396/12000 [00:01<00:51, 227.14draws/s]INFO (theano.gof.compilelock): Waiting for existing lock by process '29458' (I am process '29456')
INFO (theano.gof.compilelock): To manually release the lock, delete /home/cameron/.theano/compiledir_Linux-4.15--generic-x86_64-with-debian-buster-sid-x86_64-3.7.2-64/lock_dir
Sampling 4 chains: 3%|▎ | 412/12000 [00:08<04:08, 46.62draws/s]
---------------------------------------------------------------------------
RemoteTraceback Traceback (most recent call last)
RemoteTraceback:
"""
Traceback (most recent call last):
File "/home/cameron/anaconda3/lib/python3.7/site-packages/pymc3/parallel_sampling.py", line 82, in run
self._start_loop()
File "/home/cameron/anaconda3/lib/python3.7/site-packages/pymc3/parallel_sampling.py", line 123, in _start_loop
point, stats = self._compute_point()
File "/home/cameron/anaconda3/lib/python3.7/site-packages/pymc3/parallel_sampling.py", line 154, in _compute_point
point, stats = self._step_method.step(self._point)
File "/home/cameron/anaconda3/lib/python3.7/site-packages/pymc3/step_methods/arraystep.py", line 247, in step
apoint, stats = self.astep(array)
File "/home/cameron/anaconda3/lib/python3.7/site-packages/pymc3/step_methods/hmc/base_hmc.py", line 135, in astep
self.potential.raise_ok(self._logp_dlogp_func._ordering.vmap)
File "/home/cameron/anaconda3/lib/python3.7/site-packages/pymc3/step_methods/hmc/quadpotential.py", line 231, in raise_ok
raise ValueError('\n'.join(errmsg))
ValueError: Mass matrix contains zeros on the diagonal.
The derivative of RV `pm_xi_log__`.ravel()[0] is zero.
The derivative of RV `pm_zeta_log__`.ravel()[0] is zero.
The derivative of RV `pm_gamma_log__`.ravel()[0] is zero.
"""
The above exception was the direct cause of the following exception:
ValueError Traceback (most recent call last)
ValueError: Mass matrix contains zeros on the diagonal.
The derivative of RV `pm_xi_log__`.ravel()[0] is zero.
The derivative of RV `pm_zeta_log__`.ravel()[0] is zero.
The derivative of RV `pm_gamma_log__`.ravel()[0] is zero.
The above exception was the direct cause of the following exception:
RuntimeError Traceback (most recent call last)
<ipython-input-12-9103245aae0f> in <module>
13 potential = pm.Potential("potential", logl)
14
---> 15 trace = pm.sample(draws=1000, chains=4, tune=2000)
16 pm.traceplot(trace)
17 print(pm.summary(trace))
~/anaconda3/lib/python3.7/site-packages/pymc3/sampling.py in sample(draws, step, init, n_init, start, trace, chain_idx, chains, cores, tune, nuts_kwargs, step_kwargs, progressbar, model, random_seed, live_plot, discard_tuned_samples, live_plot_kwargs, compute_convergence_checks, use_mmap, **kwargs)
437 _print_step_hierarchy(step)
438 try:
--> 439 trace = _mp_sample(**sample_args)
440 except pickle.PickleError:
441 _log.warning("Could not pickle model, sampling singlethreaded.")
~/anaconda3/lib/python3.7/site-packages/pymc3/sampling.py in _mp_sample(draws, tune, step, chains, cores, chain, random_seed, start, progressbar, trace, model, use_mmap, **kwargs)
988 try:
989 with sampler:
--> 990 for draw in sampler:
991 trace = traces[draw.chain - chain]
992 if (trace.supports_sampler_stats
~/anaconda3/lib/python3.7/site-packages/pymc3/parallel_sampling.py in __iter__(self)
343
344 while self._active:
--> 345 draw = ProcessAdapter.recv_draw(self._active)
346 proc, is_last, draw, tuning, stats, warns = draw
347 if self._progress is not None:
~/anaconda3/lib/python3.7/site-packages/pymc3/parallel_sampling.py in recv_draw(processes, timeout)
247 else:
248 error = RuntimeError("Chain %s failed." % proc.chain)
--> 249 six.raise_from(error, old_error)
250 elif msg[0] == "writing_done":
251 proc._readable = True
~/anaconda3/lib/python3.7/site-packages/six.py in raise_from(value, from_value)
RuntimeError: Chain 1 failed.
I’m not sure why the sampler would find a region where the gradient is zero, given my log-likelihood function. Is there a way to have it record which parameter values give the bad gradients so I can troubleshoot this? Any other insights as to why PyMC3 is unhappy with this would be appreciated too.