Bringing the drift-diffusion model (DDM) to PyMC3

That also occurred to me this morning. It should be very easy to do with a few find-and-replaces. I’ll try it.

I haven’t encountered any so far.

OK, here’s something.

I spent some of today explicitly on this. Indeed, it was possible to get underflows/overflows when the RTs were really small, leading to NaNs in ps and pl. Following your suggestion, I implemented a version of the logsumexp trick, and it worked. Here is a minimal function that returns what was called ps in the previous code.

import numpy as np

def fnorm_fast(x, w):
    # Truncated series for the first-passage density, stabilized with
    # the logsumexp trick so that small RTs no longer under/overflow.
    bigk = 7  # number of series terms; k = -3, ..., 3
    k = (np.arange(bigk) - np.floor(bigk / 2)).reshape(-1, 1)
    y = w + 2 * k
    r = -np.power(y, 2) / 2 / x  # exponents of the series terms
    c = np.max(r, axis=0)  # largest exponent, factored out for stability
    p = np.exp(c + np.log(np.sum(y * np.exp(r - c), axis=0)))
    p = p / np.sqrt(2 * np.pi * np.power(x, 3))
    return p  # NaN for x <= 0; see below

It is still not perfect. I would really like it to return 0, not NaN, when x is negative. Any ideas on how to achieve that?

This looks fantastic. Have you tried sampling with NUTS?

Not yet, there are still numerous edge cases, such as when RTs are extremely small or large, when z + sz / 2 > 1, and so on. I’m confident I’ll get this working but I think it will take me a week or two (I’m slow).

Sounds like those are just invalid parameter combinations where you can do a switch statement.
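Something like this, perhaps (a NumPy sketch wrapping the fnorm_fast above; fnorm_fast_safe is a made-up name):

import numpy as np

def fnorm_fast_safe(x, w):
    # Evaluate the series only where it is defined (x > 0) and return
    # 0 elsewhere, rather than letting NaNs propagate.
    x = np.asarray(x, dtype=float)
    p = np.zeros_like(x)
    valid = x > 0
    p[valid] = fnorm_fast(x[valid], w)
    return p

In the Theano graph the analogue would be tt.switch(tt.gt(x, 0), p, 0.0), with the caveat that switch evaluates both branches, so the invalid branch must not produce NaN gradients.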

Did you figure it out? I think the "Aesara/Theano implementation of the WFTP distribution" gist needs to pass tt.

Yes, that was exactly the issue!

OK well, looks like I successfully sampled via NUTS, at least for 1 random variable with a normal prior (drift rate) with all other DDM parameters fixed at their true values. I updated the Gist (it uses another module to generate random samples currently, so not completely self-contained).

It eats memory. My 8 GB machine could just about handle it with 1 RV and 3,000 data points, which is not much data for a typical DDM study. Is this just an inevitable consequence of having such a complex likelihood function and parameter space, or is there something I can do to reduce memory usage?

When NUTS starts sampling, it's pretty fast!

Great progress! I think there should be a way to vectorize these computations (ideal) or implement them in a scan to make the graph smaller and save on memory.

I’m not sure I understand, which parts aren’t currently vectorized?

Ah, you already vectorized a lot since I last looked. The "Aesara/Theano implementation of the WFTP distribution" gist could probably be a scan loop.
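For instance, the sum over series terms could be accumulated one term at a time with theano.scan instead of stacking all of them (just a sketch under assumed names, not your exact code; the logsumexp stabilization is omitted for brevity):

import theano
import theano.tensor as tt

def series_sum_scan(x, w, bigk=7):
    # Accumulate the k-sum term by term; scan keeps the unrolled
    # per-term ops out of the graph, which should shrink its size.
    ks = tt.arange(bigk) - bigk // 2

    def step(k, acc, x, w):
        y = w + 2.0 * k
        return acc + y * tt.exp(-y ** 2 / (2.0 * x))

    sums, _ = theano.scan(
        fn=step,
        sequences=ks,
        outputs_info=tt.zeros_like(x),
        non_sequences=[x, w],
    )
    return sums[-1]  # final value of the running sum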

With this, does it still use this much memory? How is inference? Do you plan to put this into a package or contribute to PyMC3?

I’ll try using scan loops, and let you know about memory/inference after working on it some more.

I imagined that the result of all this would simply be a single module containing a pm.Continuous class for the WFPT distribution that could be dropped into PyMC3. I would rather not create (then have to maintain) an entirely new package.

You can definitely do a PR that adds this distribution.

@sammosummo Any update on this?

Unfortunately not, sorry. Have a couple of other projects to finish up before I can get back to this.

Hi @sammosummo ,

Thank you so much for sharing the gists of the DDM implementation in PyMC3. We recently managed to replicate the code with only v as the stochastic variable.

However, when specifying the t variable, the model failed to sample and threw the error:
Bad initial energy, check any log probabilities that are inf or -inf, nan or very small: x -inf
We think it might be a probability of 0 (and thus a logp of -inf) when (x - t) is negative.

Any ideas on how to overcome this issue?

Really appreciate it!

Best,
Jason

If your initial guess at t is larger than your shortest RT, then yes, for sure you'll get bad initial energy. Setting the starting value for t to be very small should solve this, but it isn't the right solution.
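For example (a sketch; the prior and test value here are made up):

import pymc3 as pm

with pm.Model():
    # start t below every plausible RT so the initial logp is finite
    t = pm.Uniform("t", lower=0, upper=1, testval=0.001)

Passing start={"t": 0.001} to pm.sample would do much the same thing.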

The typical approach is to use a mixture model so that RTs can come from a "contaminant" distribution like a uniform(0, big number). I actually implemented this in my local code, perhaps after I last responded to this thread. I made a bunch of other improvements as well.
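On the log scale, that mixture is just a logaddexp (a sketch with made-up names; ddm_logp stands for the WFPT log density and q for the contaminant probability):

import numpy as np

def mixture_logp(ddm_logp, q, upper):
    # With probability q the RT comes from a uniform(0, upper)
    # contaminant; otherwise from the WFPT density.
    return np.logaddexp(np.log1p(-q) + ddm_logp, np.log(q) - np.log(upper))

The Theano version is analogous.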

The project that dragged me away from this is in the home stretch, with a paper about to be submitted. Once that's done, I'll return to this, hopefully in the not-too-distant future.

Hi @sammosummo

Thank you so much for your answer!

I specified q ~ pm.Uniform(name="q", lower=0, upper=0.2), but it still failed partway through sampling and threw the error:

Mass matrix contains zeros on the diagonal. The derivative of RV t_log__.ravel()[0] is zero. The derivative of RV v.ravel()[0] is zero. The derivative of RV a_log__.ravel()[0] is zero.

Probably due to overflow?

However, I found that letting t be a gamma distribution bounded above by some large number (like 2) helps with this issue:

t_subj = pm.Bound(pm.Gamma, upper=2)(
    name="t_subj", mu=mu_t, sigma=sigma_t, dims="Subj_idx", testval=0.4
)

I wonder if this sounds like an appropriate solution?

Thank you again! And really excited to hear that you are returning to this! We think adding the WFPT distribution to the built-in PyMC3 distributions would be a great contribution. Looking forward to hearing more about your progress! :blush:
