Bringing the drift-diffusion model (DDM) to PyMC3

DDMs are a class of decision-making models found in cognitive neuroscience. Given their incredible popularity, I think it would be a major boon to implement DDMs in PyMC3. Unfortunately, this is somewhat tricky due to the complexity of the DDM likelihood function. I see there has been at least one previous effort to do this, but I don’t see any complete examples.

There are numerous R, Stan, and Python (of course) implementations of the DDM already, but all of them are somewhat niche or quirky and none really work that well for my specific purposes. There is already a very popular DDM implementation called HDDM, which is great, but it still uses PyMC2 and is somewhat challenging to install these days.

In my view, the simplest way to get a working DDM in PyMC3 is to port the log likelihood function from HDDM to a Theano Op and wrap it in a pm.DensityDist, or use it as an arbitrary pm.Potential. This should at least allow gradient-free MCMC, such as slice sampling, replicating the functionality of HDDM.

I took @twiecki’s Cython code from HDDM and wrote it as pure Python; see the Gist below. The main result is the logpdf_with_contaminant_exponential function, which gives the log likelihood for a single data point x, given four “basic” DDM parameters, three trial-by-trial variability parameters, and two contaminant parameters.

I also wrapped logpdf_with_contaminant_exponential as a vectorized function with Numba. On the whole it seems to work well, and evaluates quickly.

At this point, I’m stumped. I can’t figure out how to convert my Numba function into a Theano Op. It seems to be possible, but I’m just not understanding how this example works.

I’ve also read through this example in the PyMC3 docs several times, and again I’m not getting it. I can’t seem to wrap the pure Python implementation of logpdf_with_contaminant_exponential either.

Perhaps it makes more sense to just copy-paste the original Cython code and wrap that instead? I didn’t do this originally because I wanted to understand how those functions worked, and (naively) thought it would be easier to wrap pure Numba than Cython, because the former is compatible with NumPy ufuncs.

I would greatly welcome comments and assistance!


You may not be surprised to hear that I think this is an excellent idea ;).

Here is an example of using a numba function in a Theano Op:

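import numba
import numpy as np
import theano
import theano.tensor as tt


@numba.njit(parallel=True, fastmath=True)
def numba_logsumexp(p, out):
    """Row-wise logsumexp of the 2-D array p, written in place into out."""
    n, m = p.shape
    assert len(out) == n
    assert out.ndim == 1
    assert p.ndim == 2

    for i in numba.prange(n):  # rows are processed in parallel
        # Subtract the row maximum before exponentiating to avoid overflow.
        row_max = p[i, 0]
        for j in range(1, m):
            row_max = max(row_max, p[i, j])
        res = 0.0
        for j in range(m):
            res += np.exp(p[i, j] - row_max)
        out[i] = row_max + np.log(res)

class LogSumExp(theano.graph.op.Op):
    """Custom row-wise logsumexp Op that calls the Numba kernel above"""

    itypes = [tt.dmatrix]  # input: one float64 matrix
    otypes = [tt.dvector]  # output: one float64 value per row

    def perform(self, node, inputs, outputs):
        x, = inputs
        n, m = x.shape
        out = np.zeros(n, dtype=x.dtype)
        numba_logsumexp(x, out)
        outputs[0][0] = out

    def grad(self, inputs, grads):
        x, = inputs
        dout, = grads
        logsumexp = self(x)
        # The gradient of logsumexp is the row-wise softmax. The original used a
        # separate LogSumExpGrad Op (not shown); this computes it symbolically instead.
        softmax = tt.exp(x - logsumexp.dimshuffle(0, "x"))
        return [softmax * dout.dimshuffle(0, "x")]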

Since you don’t have a gradient yet, you can drop the grad method here. I do think we should try to figure this out, though. Perhaps JAX could autotrace the grad.

Let me know if that helps.

Thanks! Regarding calculating the grad, I agree this should be done, but right now I can’t crack the basic problem of getting the likelihood into PyMC3. Here’s what I tried. The log likelihood function looks like this.

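from numpy import inf, log

# full_pdf and pdf_contaminant_exponential are defined in the Gist (not shown here).
def full_ddm_loglike(x, v, sv, a, z, sz, t, st, p_outlier, l):
    # Parameter values outside the valid region get zero likelihood.
    if (
            (z < 0) or (z > 1) or (a < 0) or (t < 0)
            or (st < 0) or (sv < 0) or (sz < 0) or (sz > 1)
            or (z + sz / 2.0 > 1) or (z - sz / 2.0 < 0)
            or (t - st / 2.0 < 0) or (p_outlier < 0)
            or (p_outlier > 1)
    ):
        return -inf
    if p_outlier == 0:
        # No contaminants: log of the plain DDM density.
        return log(full_pdf(x, v, sv, a, z, sz, t, st))
    if l <= 0:
        return -inf
    # Mixture of the DDM density and an exponential contaminant density.
    p0 = full_pdf(x, v, sv, a, z, sz, t, st) * (1 - p_outlier)
    p1 = pdf_contaminant_exponential(x, l) * p_outlier
    return log(p0 + p1)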

Then I define a new Theano Op class:

import theano
import theano.tensor as tt

class FullDDMLogLike(theano.graph.op.Op):

    itypes = [tt.dmatrix]
    otypes = [tt.dvector]

    def perform(self, node, inputs, outputs):
        out = full_ddm_loglike(*inputs)
        outputs[0][0] = out

Then when I put this into a PyMC3 model:

import numpy as np
import pymc3 as pm

with pm.Model():

    x = np.ones(100)  # data
    v = pm.Normal("v")  # one stochastic to keep things simple
    sv, a, z, sz, t, st = 0.01, 1.0, 0.5, 0.01, 0.5, 0.01  # fixed ddm params
    p_outlier, l = 0.001, 0.1  # fixed contaminant params

    like = pm.DensityDist('like', FullDDMLogLike, observed=(x,))

I get the error:

TypeError: FullDDMLogLike() takes no arguments

I think you may need to instantiate the object (FullDDMLogLike()).
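The error is ordinary Python class behaviour: `FullDDMLogLike` was passed to `pm.DensityDist` as a class, so when PyMC3 calls it with the observed data, Python tries to construct a new instance and hands the data to a constructor that accepts none. A minimal reproduction, using a hypothetical stand-in base class rather than Theano’s real `Op`:

```python
class StandInOp:
    """Hypothetical stand-in for an Op base class; only instances are callable."""

    def __call__(self, *inputs):
        # A real Theano Op would build a graph node here; this just echoes inputs.
        return ("applied", inputs)


class FullDDMLogLike(StandInOp):
    pass  # defines no __init__, so the constructor accepts no arguments


data = [0.5, 0.7]

try:
    FullDDMLogLike(data)  # calling the class passes `data` to the constructor
    error = None
except TypeError as exc:
    error = str(exc)  # on recent Python versions: "FullDDMLogLike() takes no arguments"

op = FullDDMLogLike()  # instantiate first...
result = op(data)      # ...then call the instance on the inputs
```

With the real Op the fix is the same: pass `FullDDMLogLike()` rather than `FullDDMLogLike`.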

However – and excuse my ignorance – what is the difference between the DDM and a random walk? Is there a problem with fixing the boundaries, placing all the observations at the boundary, and simply using a GaussianRandomWalk? (fixing the boundary should be OK, since the same hitting time distributions should be recoverable after re-scaling the drift/diffusion params)

Thanks! I’ve made a bit more progress today (it looks like it’s actually sampling now!), and this was definitely one of the issues. I aim to make another post in this thread later today.

I haven’t looked at the GaussianRandomWalk implementation, but the DDM is much more complicated than plain Wiener diffusion/random walk because the “basic” diffusion process has to be augmented by trial-by-trial variability parameters to capture general features seen in real datasets.

GaussianRandomWalk would be very inefficient because it simulates every individual random walk. The WFPT (Wiener first-passage time) likelihood is a closed-form solution for the boundary-crossing times that integrates over the individual random walks.
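To make the cost difference concrete, here is a minimal Euler-Maruyama simulation of the basic DDM (no trial-by-trial variability). The conventions are assumptions of this sketch: unit diffusion coefficient, boundaries at 0 and a, and evidence starting at w·a. Every trial needs one random step per time bin until it leaves the corridor.

```python
import numpy as np


def simulate_ddm(v, a, w, t0=0.0, n_trials=10_000, dt=1e-3, max_time=10.0, seed=0):
    """Euler-Maruyama simulation of a basic DDM (no trial-by-trial variability).

    Evidence starts at w * a and diffuses with drift v and unit diffusion
    coefficient until it leaves (0, a). Returns signed response times:
    positive for the upper boundary, negative for the lower boundary,
    NaN for trials that have not finished by max_time.
    """
    rng = np.random.default_rng(seed)
    x = np.full(n_trials, w * a)
    rt = np.full(n_trials, np.nan)
    active = np.ones(n_trials, dtype=bool)
    sqrt_dt = np.sqrt(dt)
    for step in range(1, int(round(max_time / dt)) + 1):
        idx = np.flatnonzero(active)
        if idx.size == 0:
            break  # every trial has hit a boundary
        # One Euler step for every trial that is still running.
        x[idx] += v * dt + sqrt_dt * rng.standard_normal(idx.size)
        t = t0 + step * dt
        up = idx[x[idx] >= a]
        low = idx[x[idx] <= 0.0]
        rt[up] = t
        rt[low] = -t
        active[up] = False
        active[low] = False
    return rt
```

With `v=1, a=1, w=0.5`, the fraction of negative (lower-boundary) response times should be close to the analytic absorption probability (e^(-2vaw) - e^(-2va)) / (1 - e^(-2va)) ≈ 0.27, up to Monte Carlo noise and a small discretization bias. The WFPT density replaces this whole loop with a handful of series terms per data point.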

@sammosummo Once you have it running it should be pretty easy to turn this into a proper likelihood distribution so that it could be used with observed instead of a Potential.

Does that impact the likelihood in a non-factorizable way, or would it be the equivalent of having trial-indexed offsets to the RW parameters (i.e. mu_drift + mu_trial_offset, disp + disp_trial_offset)?

Seeing numerical integration implemented in theano really pounds on my “there’s got to be a better way” drum.

I don’t think explicitly modeling the trial-by-trial offsets of the DDM parameters is the right approach under most circumstances. Typically one has tens of thousands of trials, possibly hundreds of thousands, so the model ends up with one, two, or three times as many RVs as trials.

I’ve tried this with other psychophysical models. It does work. You specify mean and variance RVs for the trial-varying parameter, then include trial-by-trial parameter RVs that are estimated hierarchically. The trial-by-trial RVs essentially take on their prior distribution since there is hardly any evidence with which to update them, but it does propagate the uncertainty (if that’s the right phrase) to the all-trials level. But practically the model is useless; it takes many orders of magnitude longer to sample and because the trial-by-trial RVs are all over the place, LOO, WAIC etc. are all screwed up.
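For drift variability, at least, the trial-level offsets can be integrated out exactly instead of sampled. If each trial’s drift is v_i ~ Normal(v, sv^2), only the factor exp(-v a w - v^2 t / 2) of the lower-boundary density (start at w·a, unit diffusion) depends on the drift, and its Gaussian average has a closed form. As far as I know this is how HDDM treats sv, while sz and st are integrated numerically. The sketch below checks the closed form against Gauss-Hermite quadrature; the function names are illustrative.

```python
import numpy as np


def drift_factor(t, v, a, w):
    """Drift-dependent factor of the lower-boundary WFPT density."""
    return np.exp(-v * a * w - v**2 * t / 2.0)


def drift_factor_sv(t, v, sv, a, w):
    """Same factor averaged analytically over drift ~ Normal(v, sv**2)."""
    denom = sv**2 * t + 1.0
    expo = ((a * w * sv) ** 2 - 2.0 * a * v * w - v**2 * t) / (2.0 * denom)
    return np.exp(expo) / np.sqrt(denom)


def drift_factor_sv_quadrature(t, v, sv, a, w, n_nodes=40):
    """Numerical check: Gauss-Hermite (probabilists') average over the drift."""
    nodes, weights = np.polynomial.hermite_e.hermegauss(n_nodes)
    drifts = v + sv * nodes
    # The weights integrate against exp(-x**2 / 2), so normalise by sqrt(2*pi).
    return np.sum(weights * drift_factor(t, drifts, a, w)) / np.sqrt(2.0 * np.pi)
```

Multiplying `drift_factor_sv` by the drift-free part f(t/a^2 | 0, 1, w) / a^2 gives the sv-marginalised density, so the likelihood still factorises over trials and no per-trial RVs are needed. With `sv=0` it reduces to `drift_factor`.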

This makes sense – though integrating out the random walk shouldn’t “fix” any issues with model specification. I’m honestly surprised that sampling should be so much slower than explicit numerical integration - the whole purpose of sampling is that we can’t perform numerical integration in high dimensions. I suppose in this case the integral is 1-d, but the walk itself is (# of timepoints)-d.

Thank you (both) for helping to clarify.

After reading Navarro and Fuss and @twiecki’s Cython code again, and trying to get working JAX code, I’m beginning to think the grad can’t be autotraced for Navarro and Fuss’ fast WFPT likelihood function.

The Navarro and Fuss likelihood function calculates the minimum number of terms $\kappa$ needed to achieve some acceptable error for each of two candidate series (small-time and large-time), then adaptively uses whichever needs the smaller $\kappa$. So the final function contains numerous nested if/else statements, as well as a for loop in the selected series whose range depends on the inputs.
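A pure-Python sketch of that adaptive scheme, written from the Navarro and Fuss series as I understand them: time is normalised as u = t/a^2, both term counts are computed from the error tolerance `err`, and the cheaper series is summed. It returns the density of hitting the lower boundary; the upper-boundary density follows by reflection (v to -v, w to 1 - w). The function names and the default `err` are illustrative, not taken from the Gist.

```python
from math import ceil, exp, floor, log, pi, sin, sqrt


def n_terms(u, err):
    """Terms needed by the large-time (kl) and small-time (ks) series at time u."""
    if pi * u * err < 1:
        kl = max(sqrt(-2 * log(pi * u * err) / (pi**2 * u)), 1 / (pi * sqrt(u)))
    else:
        kl = 1 / (pi * sqrt(u))
    if 2 * sqrt(2 * pi * u) * err < 1:
        ks = max(2 + sqrt(-2 * u * log(2 * sqrt(2 * pi * u) * err)), sqrt(u) + 1)
    else:
        ks = 2
    return kl, ks


def fpt_standard(u, w, err):
    """Lower-boundary density for v = 0, a = 1 and relative start point w."""
    kl, ks = n_terms(u, err)
    p = 0.0
    if ks < kl:
        # Small-time series; the loop length depends on u (data-dependent).
        K = ceil(ks)
        for k in range(-floor((K - 1) / 2), ceil((K - 1) / 2) + 1):
            p += (w + 2 * k) * exp(-((w + 2 * k) ** 2) / (2 * u))
        p /= sqrt(2 * pi * u**3)
    else:
        # Large-time series.
        K = ceil(kl)
        for k in range(1, K + 1):
            p += k * exp(-(k**2) * (pi**2) * u / 2) * sin(k * pi * w)
        p *= pi
    return p


def wfpt_lower(t, v, a, w, err=1e-10):
    """Density of hitting the lower boundary at decision time t."""
    if t <= 0:
        return 0.0
    u = t / a**2
    return fpt_standard(u, w, err) * exp(-v * a * w - v**2 * t / 2) / a**2
```

Both the `if` on `ks < kl` and the loop bounds depend on the data, which is exactly what makes this version hard to trace for automatic differentiation.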

In pure Python, coding up an elementwise likelihood function poses no problem at all. Moreover, I can JIT the candidate functions and vectorize-JIT the likelihood function with Numba. In the end, the compiled likelihood function is super fast and accepts scalars as well as vectors. I successfully wrapped this in a Theano Op, although the wrapped Op only accepted Theano vectors, not scalars; I don’t understand why.

I tried to recapitulate these steps using JAX rather than Numba, but JAX doesn’t seem to be able to trace loops with variable lengths, so the grad can’t be computed. I’m wondering whether this is a logical impossibility or whether I simply don’t know the right trick in JAX.

I also tried building the likelihood function in Theano, but as of right now a fresh install of Theano won’t even compile on my machine.

Sounds like great progress. I think the JAX way needs a different approach:

  • Always evaluate both directions (for fast and slow RTs)
  • Fix the number of loop iterations to something high enough (probably 7 is good enough, but this could be tested)

That way there is some wasted computation, but I hope this is offset (at least partially) by vectorization. For that likelihood, JAX should be able to infer the gradients.
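A NumPy sketch of that suggestion, assuming K = 7: both series are always summed over a fixed range of k, and `np.where` picks one per data point using the same term-count rule as the adaptive version, so there are no Python branches and no data-dependent loop lengths. If the adaptive rule never asks for more than K terms, the fixed sums contain every term the adaptive version would use. The idea is that swapping `numpy` for `jax.numpy` gives something `jax.grad` can trace; I have not verified that here. One known JAX pitfall: `where` evaluates both branches, and a discarded branch that produces inf or NaN can still poison the gradient.

```python
import numpy as np

K = 7  # fixed number of series terms (assumed to be enough; worth testing)
K_SMALL = np.arange(-(K // 2), K - K // 2)  # k = -3..3 for K = 7
K_LARGE = np.arange(1, K + 1)  # k = 1..7


def wfpt_lower_fixed(t, v, a, w, err=1e-10):
    """Branch-free lower-boundary WFPT density for decision times t > 0."""
    t = np.asarray(t, dtype=float)[..., None]  # trailing axis indexes series terms
    u = t / a**2
    # Adaptive term counts, used only to choose a series. Log arguments are
    # clipped to at most 1, which reproduces the else-branches without NaNs.
    arg_l = np.minimum(np.pi * u * err, 1.0)
    kl = np.maximum(np.sqrt(-2 * np.log(arg_l) / (np.pi**2 * u)), 1 / (np.pi * np.sqrt(u)))
    arg_s = np.minimum(2 * np.sqrt(2 * np.pi * u) * err, 1.0)
    ks = np.where(arg_s < 1, np.maximum(2 + np.sqrt(-2 * u * np.log(arg_s)), np.sqrt(u) + 1), 2.0)
    # Always evaluate both series over the fixed ranges of k.
    wk = w + 2 * K_SMALL
    p_small = np.sum(wk * np.exp(-(wk**2) / (2 * u)), axis=-1, keepdims=True)
    p_small = p_small / np.sqrt(2 * np.pi * u**3)
    terms = K_LARGE * np.exp(-(K_LARGE**2) * np.pi**2 * u / 2) * np.sin(K_LARGE * np.pi * w)
    p_large = np.pi * np.sum(terms, axis=-1, keepdims=True)
    p = np.where(ks < kl, p_small, p_large)
    return (p * np.exp(-v * a * w - v**2 * t / 2) / a**2)[..., 0]
```

In a JAX port, `K` would stay a plain Python int, so the number of terms is fixed at trace time.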

OK, we are making progress!

I now have a JITable log likelihood function that produces seemingly sensible values.

What in your opinion is the best way to validate this function (that is, check it is calculating the log likelihood correctly)?

Awesome, want to share the code? Curious to take a look. Best to compare against the original implementation, or just borrow the lookup tables from the HDDM tests (two files under hddm/ on master in the hddm-devs/hddm repository on GitHub).
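Besides the HDDM lookup tables, two cheap checks need no reference data: the lower- and upper-boundary densities should integrate to one in total, and the mass at the lower boundary should match the closed-form absorption probability (e^(-2vaw) - e^(-2va)) / (1 - e^(-2va)) for unit diffusion between boundaries 0 and a. A sketch of a generic checker, assuming a hypothetical density signature `pdf_lower(t, v, a, w)` and getting the upper-boundary density by reflection:

```python
import numpy as np


def absorption_prob_lower(v, a, w):
    """P(reaching 0 before a) for drift v, unit diffusion, start at w * a."""
    if v == 0:
        return 1.0 - w
    return (np.exp(-2 * v * a * w) - np.exp(-2 * v * a)) / (1 - np.exp(-2 * v * a))


def check_wfpt_pdf(pdf_lower, v, a, w, t_max=20.0, n_grid=20_001, t_min=1e-6):
    """Integrate a lower-boundary density and its mirror image on a fine grid.

    Returns (lower_mass, expected_lower_mass, total_mass). For a correct
    density the first two should agree and total_mass should be close to 1.
    """
    t = np.linspace(t_min, t_max, n_grid)
    lower = np.array([pdf_lower(ti, v, a, w) for ti in t])
    # Upper-boundary density by reflection: v -> -v, w -> 1 - w.
    upper = np.array([pdf_lower(ti, -v, a, 1.0 - w) for ti in t])

    def trapezoid(y):
        return float(np.sum(0.5 * (y[1:] + y[:-1]) * np.diff(t)))

    lower_mass = trapezoid(lower)
    return lower_mass, absorption_prob_lower(v, a, w), lower_mass + trapezoid(upper)
```

It is worth running for several (v, a, w) combinations, including w near 0 or 1 and large |v|, where numerical trouble would be most likely.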

Of course, what’s the best way? I could create a new GitHub repo or a Gist, or just paste it here. (I’m still more scientist than coder at this point).

Gist would be easiest.


This is awesome, can you get the gradient?

I can wrap jax_wfpt_sumlogp in jax.grad() without error, but it sometimes produces NaNs erroneously. I think there is an issue somewhere, perhaps when numerically integrating out sz and st; that’s the part I’m currently least sure of.
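On the sz/st step: with uniform variability, the density is the average of the basic density over the starting point in [z - sz/2, z + sz/2] (and likewise over the non-decision time for st). One way to keep that average free of data-dependent loops is a fixed-node Gauss-Legendre rule. This is a swapped-in technique for illustration, not necessarily what the Gist or HDDM does, and `uniform_average` is a hypothetical helper. Its nodes lie strictly inside the interval, so the integrand is never evaluated exactly at the edges, which may matter if the NaNs come from edge evaluations.

```python
import numpy as np


def uniform_average(f, center, width, n_nodes=8):
    """Average of f over Uniform(center - width/2, center + width/2).

    Uses a fixed n-point Gauss-Legendre rule, so the number of function
    evaluations does not depend on the inputs. `f` must accept an array.
    """
    nodes, weights = np.polynomial.legendre.leggauss(n_nodes)  # rule on [-1, 1]
    points = center + 0.5 * width * nodes  # map the nodes onto the interval
    # The integral is (width / 2) * sum(weights * f); dividing by the
    # interval length `width` leaves 0.5 * sum(weights * f).
    return 0.5 * np.sum(weights * f(points))
```

With width 0 the nodes collapse onto the centre and the weights sum to 2, so the same code returns f(center) without a special case (and without a branch that a tracer would have to handle).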

The ps and pl calculations might cause some trouble with underflow or overflow. Could it help to place the terms into an array and use the equivalent of pm.math.logsumexp?
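A sketch of that suggestion for the small-time series: keep each term as a sign and a log-magnitude, and pull out the largest exponent before exponentiating. It is the same trick as logsumexp but allows negative weights, which are needed because terms with k < 0 are negative. `log_small_time_series` is a hypothetical helper in Navarro and Fuss notation (normalised time u, relative start point w); it assumes the sum is positive, as a density must be.

```python
import numpy as np


def signed_logsumexp(log_mags, signs):
    """log(sum(signs * exp(log_mags))), assuming the sum is positive."""
    m = np.max(log_mags)
    return m + np.log(np.sum(signs * np.exp(log_mags - m)))


def log_small_time_series(u, w, k_max=3):
    """Log of the small-time series for the v = 0, a = 1 lower-boundary density."""
    k = np.arange(-k_max, k_max + 1)
    wk = w + 2 * k  # never zero for 0 < w < 1
    # log|(w + 2k) exp(-(w + 2k)^2 / (2u))| and the sign of each term
    log_mags = np.log(np.abs(wk)) - wk**2 / (2 * u)
    signs = np.sign(wk)
    return signed_logsumexp(log_mags, signs) - 0.5 * np.log(2 * np.pi * u**3)
```

For very small u the direct sum underflows to 0, so its log is -inf, while this version stays finite (about -800 for u = 1e-4, w = 0.4).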


Thinking more about this, you could probably do the same calculation in Aesara directly and get gradients for all backends that way.