Sampling crashes with "The derivative of RV [whatever] is zero" when using a custom likelihood defined with pm.Potential

I’m experimenting with a certain stochastic model of disease. I have an array where each column represents a person and each row represents a time. The entries are integers corresponding to “states” the person might be in – for this disease they are “Susceptible” (0), “Latent” (1), “Infectious” (2) and “Recovered” (3). The leftmost column is an exception: its entries are floating point numbers representing timestamps.

At the moment I’m trying to fit three parameters to this data: gamma, zeta, and xi. Loosely speaking, gamma quantifies the transmission risk per infected person in the community, zeta represents the average length of the incubation (“Latent”) period, and xi represents the average length of the period of infectiousness.

So that you can follow along at home, I have uploaded fake data from a simulation where I know the underlying true values of the three parameters for this particular model. It should be attached to this post:

test_data.txt (2.5 MB)

Here’s the code to prepare it:

import numpy as np

test_trajs = np.loadtxt(fname="test_data.txt", delimiter=",")

times = test_trajs[:,0]
## we need the length of time spent in each row
dts = np.diff(times)

## we need to compare "states now" to "states next"
state_history = test_trajs[:,1:].astype(int)
states_next = np.roll(state_history, shift=-1, axis=0)

## remove last entry to prevent off-by-one error
state_history = state_history[:-1,:]
states_next = states_next[:-1,:]

## 2D boolean arrays that appear in the likelihood
changes = state_history!=states_next
is_s = (state_history==0)
is_l = (state_history==1)
is_i = (state_history==2)
is_r = (state_history==3)
## 1D integer array for number of ppl in infectious state at each interval
n_infected = is_i.sum(axis=1)
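
To make the shapes concrete, here is a minimal sketch of the same preparation steps on a tiny made-up trajectory. The numbers and the `toy_` names are invented for illustration and are not taken from test_data.txt.

```python
import numpy as np

# Hypothetical toy data: column 0 is the timestamp,
# columns 1-2 are the states of two people (0=S, 1=L, 2=I, 3=R).
toy = np.array([
    [0.0, 0, 2],
    [0.5, 1, 2],
    [1.2, 1, 3],
    [2.0, 2, 3],
])

toy_dts = np.diff(toy[:, 0])                 # one interval length per row pair

toy_states = toy[:, 1:].astype(int)
toy_next = np.roll(toy_states, shift=-1, axis=0)
# drop the last row, which np.roll wrapped around to the first row
toy_states, toy_next = toy_states[:-1], toy_next[:-1]

toy_changes = toy_states != toy_next         # True where a person changes state
toy_n_infected = (toy_states == 2).sum(axis=1)

print(toy_dts)
print(toy_changes)
print(toy_n_infected)
```

Four timestamps give three intervals, so every array used in the likelihood has three rows.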

Now, I have a function defined which takes in some proposed values of the parameters, and outputs a (negative) log-likelihood given the data. This is just the way the log-likelihood is defined for this kind of model. Unless I am mistaken, this function is always differentiable with respect to gamma, zeta, and xi so long as all three are positive. It’s numpy-vectorized so that it runs fast and avoids if-statements.

def negative_log_likelihood(x):
    gamma, zeta, xi = x
    qxus = ((is_s * n_infected[:,None] * gamma) +
            (is_l * zeta) + (is_i * xi))
    logls = np.log(qxus ** changes) - (qxus * dts[:, None])
    logl = logls.sum()

    return -logl
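
Written out, the function above computes the following. This is only a transcription of the code; the symbols q, c, n, s and Δt are introduced here for illustration (q for qxus, c for changes, n for n_infected, s for state_history, Δt for dts).

```latex
\log L(\gamma, \zeta, \xi)
  = \sum_{t} \sum_{i} \Bigl[ c_{t,i} \, \log q_{t,i} \;-\; q_{t,i} \, \Delta t_{t} \Bigr],
\qquad
q_{t,i} = \gamma \, n_{t} \, \mathbf{1}[s_{t,i} = 0]
        + \zeta \, \mathbf{1}[s_{t,i} = 1]
        + \xi \, \mathbf{1}[s_{t,i} = 2]
```

The term c log q is taken to be 0 whenever c = 0, even if q = 0, which matches the code: there, q ** c evaluates to 1 and its log is 0. Note that q is exactly 0 for anyone in the Recovered state, and for Susceptible people while nobody is infectious.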

I can do a simple convex optimization on this to get maximum likelihood estimates of the parameters:

from scipy.optimize import minimize
bounds = [(0,np.inf), (0, np.inf), (0, np.inf)]

mle = minimize(negative_log_likelihood, x0=(1,1,1), bounds=bounds)

This outputs:

      fun: 418.72800976337373
 hess_inv: <3x3 LbfgsInvHessProduct with dtype=float64>
      jac: array([-0.00002842, -0.00001137, -0.00001705])
     nfev: 44
      nit: 10
   status: 0
  success: True
        x: array([0.41231514, 0.85600939, 0.60539016])

Which is close to the correct values of the parameters I used to generate the data (0.4, 0.8, and 0.6). It gets similarly close answers when I generate another simulation with different true parameter values (details of the simulation function are omitted here for brevity). So, I believe the log-likelihood function implementation is correct.

Now, I want to use PyMC3 to get a Bayesian posterior for the three parameters. I thought I could simply rewrite the log-likelihood function using the pm.math functions and then add it to the model via pm.Potential:

import pymc3 as pm

with pm.Model() as the_model:
    pm_gamma = pm.HalfNormal("pm_gamma", sd=1)
    pm_zeta = pm.HalfNormal("pm_zeta", sd=1)
    pm_xi = pm.HalfNormal("pm_xi", sd=1)

    qxus = ((is_s * n_infected[:,None] * pm_gamma) +
            (is_l * pm_zeta) + (is_i * pm_xi))

    logls = pm.math.log(qxus ** changes) - (qxus * dts[:,None])
    logl = pm.math.sum(logls)

    potential = pm.Potential("potential", logl)

    trace = pm.sample(draws=1000, chains=4, tune=2000)

This does start sampling for a while, but it eventually breaks with the following long error:

Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [pm_xi, pm_zeta, pm_gamma]
Sampling 4 chains:   3%|▎         | 396/12000 [00:01<00:51, 227.14draws/s]INFO (theano.gof.compilelock): Waiting for existing lock by process '29458' (I am process '29456')
INFO (theano.gof.compilelock): To manually release the lock, delete /home/cameron/.theano/compiledir_Linux-4.15--generic-x86_64-with-debian-buster-sid-x86_64-3.7.2-64/lock_dir
Sampling 4 chains:   3%|▎         | 412/12000 [00:08<04:08, 46.62draws/s] 
RemoteTraceback                           Traceback (most recent call last)
Traceback (most recent call last):
  File "/home/cameron/anaconda3/lib/python3.7/site-packages/pymc3/", line 82, in run
  File "/home/cameron/anaconda3/lib/python3.7/site-packages/pymc3/", line 123, in _start_loop
    point, stats = self._compute_point()
  File "/home/cameron/anaconda3/lib/python3.7/site-packages/pymc3/", line 154, in _compute_point
    point, stats = self._step_method.step(self._point)
  File "/home/cameron/anaconda3/lib/python3.7/site-packages/pymc3/step_methods/", line 247, in step
    apoint, stats = self.astep(array)
  File "/home/cameron/anaconda3/lib/python3.7/site-packages/pymc3/step_methods/hmc/", line 135, in astep
  File "/home/cameron/anaconda3/lib/python3.7/site-packages/pymc3/step_methods/hmc/", line 231, in raise_ok
    raise ValueError('\n'.join(errmsg))
ValueError: Mass matrix contains zeros on the diagonal. 
The derivative of RV `pm_xi_log__`.ravel()[0] is zero.
The derivative of RV `pm_zeta_log__`.ravel()[0] is zero.
The derivative of RV `pm_gamma_log__`.ravel()[0] is zero.

The above exception was the direct cause of the following exception:

ValueError                                Traceback (most recent call last)
ValueError: Mass matrix contains zeros on the diagonal. 
The derivative of RV `pm_xi_log__`.ravel()[0] is zero.
The derivative of RV `pm_zeta_log__`.ravel()[0] is zero.
The derivative of RV `pm_gamma_log__`.ravel()[0] is zero.

The above exception was the direct cause of the following exception:

RuntimeError                              Traceback (most recent call last)
<ipython-input-12-9103245aae0f> in <module>
     13     potential = pm.Potential("potential", logl)
---> 15     trace = pm.sample(draws=1000, chains=4, tune=2000)
     16     pm.traceplot(trace)
     17     print(pm.summary(trace))

~/anaconda3/lib/python3.7/site-packages/pymc3/ in sample(draws, step, init, n_init, start, trace, chain_idx, chains, cores, tune, nuts_kwargs, step_kwargs, progressbar, model, random_seed, live_plot, discard_tuned_samples, live_plot_kwargs, compute_convergence_checks, use_mmap, **kwargs)
    437             _print_step_hierarchy(step)
    438             try:
--> 439                 trace = _mp_sample(**sample_args)
    440             except pickle.PickleError:
    441                 _log.warning("Could not pickle model, sampling singlethreaded.")

~/anaconda3/lib/python3.7/site-packages/pymc3/ in _mp_sample(draws, tune, step, chains, cores, chain, random_seed, start, progressbar, trace, model, use_mmap, **kwargs)
    988             try:
    989                 with sampler:
--> 990                     for draw in sampler:
    991                         trace = traces[draw.chain - chain]
    992                         if (trace.supports_sampler_stats

~/anaconda3/lib/python3.7/site-packages/pymc3/ in __iter__(self)
    344         while self._active:
--> 345             draw = ProcessAdapter.recv_draw(self._active)
    346             proc, is_last, draw, tuning, stats, warns = draw
    347             if self._progress is not None:

~/anaconda3/lib/python3.7/site-packages/pymc3/ in recv_draw(processes, timeout)
    247             else:
    248                 error = RuntimeError("Chain %s failed." % proc.chain)
--> 249             six.raise_from(error, old_error)
    250         elif msg[0] == "writing_done":
    251             proc._readable = True

~/anaconda3/lib/python3.7/site-packages/ in raise_from(value, from_value)

RuntimeError: Chain 1 failed.

I’m not sure why it should find a region with a gradient of zero anywhere, given my log-likelihood function. Is there a way to have it record which parameter values give the bad gradients so I can troubleshoot this? Any other insights as to why PyMC3 is unhappy with this would be appreciated too.

Seems that your logp implementation breaks the gradient:

import theano.tensor as tt
dgamma, dzeta, dxi = tt.grad(logl, [pm_gamma, pm_zeta, pm_xi])
inputdict = {pm_gamma: mle['x'][0], pm_zeta:mle['x'][1], pm_xi:mle['x'][2]}
print(logl.eval(inputdict), dgamma.eval(inputdict), dzeta.eval(inputdict), dxi.eval(inputdict))
# ==> -418.7280097633918 nan nan nan

It is likely the line

logls = pm.math.log(qxus ** changes) - (qxus * dts[:,None])

Hmmm, usually not having a gradient is not a deal breaker (e.g., see comment in), but in this case it certainly does not help the debugging.

[edit:] having no gradient for any of the RVs is indeed a deal breaker.

[edit 2]:
OK, so I figured out a version that works:

def logp(gamma, zeta, xi):
    qxus = ((is_s * n_infected[:, None] * gamma) +
            (is_l * zeta) + (is_i * xi))
    x = tt.zeros(qxus.shape)
    x = tt.inc_subtensor(x[changes], qxus[changes])
    x = tt.inc_subtensor(x[~changes], 1.)

    logls = pm.math.log(x) - (qxus * dts[:, None])
    return pm.math.sum(logls)

with pm.Model() as the_model:
    gamma = pm.HalfNormal("gamma", sd=1)
    zeta = pm.HalfNormal("zeta", sd=1)
    xi = pm.HalfNormal("xi", sd=1)

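    # pass the logp function defined above; it is called with the
    # entries of `observed` as keyword arguments
    potential = pm.DensityDist("potential", logp,
                               observed={"gamma": gamma,
                                         "zeta": zeta,
                                         "xi": xi})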
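
As a sanity check on the piece-by-piece construction, here is a NumPy-only sketch (made-up numbers, plain array indexing standing in for tt.inc_subtensor) showing that it yields the same forward values as `qxus ** changes`:

```python
import numpy as np

# Made-up rates with one exact zero, and a made-up boolean change mask.
q = np.array([[0.4, 0.0],
              [0.8, 0.6]])
changes = np.array([[True, False],
                    [False, True]])

# Build the tensor piece by piece, as logp() does with inc_subtensor:
x = np.zeros(q.shape)
x[changes] = q[changes]      # keep the rate where a change happened
x[~changes] = 1.0            # use 1 elsewhere, so its log contributes 0

same_forward = np.allclose(x, q ** changes)
print(same_forward)
```

The log-likelihood value is therefore unchanged; only the way the expression is built differs.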

For full debugging note see:


Oh wow, I wasn’t expecting you to fix the whole thing. Thank you so much!

If it isn’t too much trouble, could you explain why your version has a working gradient and mine doesn’t? I’m not that familiar with theano and how it works under the hood. I intend to eventually expand the model to incorporate disease dynamics that depend on covariate information (e.g. pairwise properties like “lives in the same household”) and it will involve adding a bunch more similar terms to the model, and I want to be confident enough to do it properly.

From reading it, it seems to have something to do with the log(qxus^changes) term, but I don’t know the significance of the subtensors in your implementation.

Yes, for some reason this operation breaks the gradient! Similarly, if you check the notebook, operations like tt.switch also break the gradient. The question then becomes how to re-express that part as something that does not break the gradient. Using subtensors (basically constructing that tensor piece by piece) seems to work.
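
One plausible explanation, offered as a sketch rather than a confirmed diagnosis: if the gradient of `qxus ** changes` is computed with the textbook power rule d/dx x**y = y * x**(y - 1), then any entry where the rate is exactly 0 and `changes` is False gives 0 * 0**(-1) = 0 * inf = nan, even though the forward value 0**0 = 1 is fine. That the autodiff applies this rule is an assumption here; the NumPy code below (helper name is made up) only reproduces the arithmetic.

```python
import numpy as np

def pow_grad_wrt_base(x, y):
    """Textbook power rule d/dx x**y = y * x**(y - 1) (assumed, see text)."""
    with np.errstate(divide="ignore", invalid="ignore"):
        return y * x ** (y - 1.0)

q = np.array([0.6, 0.0])   # second entry: a rate of exactly zero
c = np.array([1.0, 0.0])   # second entry: no state change in that interval

forward = q ** c                      # 0.0 ** 0.0 evaluates to 1.0
naive_grad = pow_grad_wrt_base(q, c)  # second entry is 0 * inf

# Replace the rate by 1 wherever there is no change *before* taking the log,
# which is what the subtensor construction achieves.
safe_q = np.where(c > 0, q, 1.0)
safe_value = c * np.log(safe_q)
safe_grad = c / safe_q                # derivative of c * log(q) w.r.t. q

print(forward, naive_grad)
print(safe_value, safe_grad)
```

A single nan entry is enough to turn the summed gradient into nan for every parameter, which would be consistent with the `nan nan nan` output earlier in the thread.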
