Drift diffusion models (DDMs) are a class of decision-making models found in cognitive neuroscience. Given their incredible popularity, I think it would be a major boon to implement DDMs in PyMC3. Unfortunately, this is somewhat tricky due to the complexity of the DDM likelihood function. I see there has been at least one previous effort to do this, but I don’t see any complete examples.
There are numerous R, Stan, and Python (of course) implementations of the DDM already, but all of them are somewhat niche or quirky and none really works that well for my specific purposes. There is already a very popular DDM implementation called HDDM, which is great, but it still uses PyMC2 and is somewhat challenging to install these days.
In my view, the simplest way to get a working DDM in PyMC3 is to port the log-likelihood function from HDDM to a Theano Op and wrap it in a pm.DensityDist, or use it as an arbitrary pm.Potential. This should at least allow gradient-free MCMC, such as slice sampling, replicating the functionality of HDDM.
I took @twiecki's Cython code from HDDM and rewrote it in pure Python; see the Gist below. The main result is the logpdf_with_contaminant_exponential function, which gives the log likelihood for a single data point x, given four “basic” DDM parameters, three trial-by-trial variability parameters, and two contaminant parameters.
I also wrapped logpdf_with_contaminant_exponential as a vectorized function with Numba. On the whole it seems to work well, and evaluates quickly.
At this point, I’m stumped. I can’t figure out how to convert my Numba function into a Theano Op. It seems to be possible, but I’m just not understanding how this example works.
I’ve also read through this example in the PyMC3 docs several times, and again I’m not getting it. I can’t seem to wrap the pure Python implementation of logpdf_with_contaminant_exponential either.
Perhaps it makes more sense to just copy-paste the original Cython code and wrap that instead? I didn’t do this originally because I wanted to understand how those functions worked, and (naively) thought it would be easier to wrap a Numba function than Cython, because the former is compatible with NumPy ufuncs.
I would greatly welcome comments and assistance!