Bringing the drift-diffusion model (DDM) to PyMC3

Adding a bit more information…

When only specifying q = pm.Uniform(name="q", lower=0, upper=0.1), sampling the mixture model is pretty unstable and leads to a lot of divergences. See the plot below: the orange chain is strongly misled by the contaminant exponential distribution. I guess the gradient estimate for switching between the mixture components (DDM and contaminant) is not very reliable and leads to a high estimate for q. I wonder whether using the built-in Mixture distribution instead of the custom mixture logp would help with this issue?
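For reference, here is a numpy-only sketch of what the custom mixture logp is doing under the hood, using logsumexp for numerical stability. The `fake_ddm` density below is a stand-in normal, not the real WFPT density, and `mixture_logp` is an illustrative name, not a function from the actual code:

```python
import numpy as np
from scipy import stats
from scipy.special import logsumexp

def mixture_logp(x, q, l, ddm_logp):
    """log[(1 - q) * ddm(x) + q * Exponential(rate=l)(x)], computed stably."""
    log_ddm = np.log1p(-q) + ddm_logp(x)                      # DDM component
    log_con = np.log(q) + stats.expon.logpdf(x, scale=1.0 / l)  # contaminant
    return logsumexp(np.stack([log_ddm, log_con]), axis=0)

# stand-in "DDM" density for illustration only: a normal over RTs
fake_ddm = lambda x: stats.norm.logpdf(x, loc=0.5, scale=0.2)

x = np.array([0.3, 0.5, 2.0])
lp = mixture_logp(x, q=0.05, l=0.5, ddm_logp=fake_ddm)
```

One can check that `exp(lp)` equals the weighted sum of the two component densities directly; the logsumexp form just avoids underflow when both components are tiny.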

Here is the code for data generation and model building:

v = 1
sv = 0
a = 0.8
z = 0.5
sz = 0.0
t = 0.1
st = 0.
q = 0.05
l = 0.5
r = 0.8
size = 3000

x = aesara_wfpt_con_rvs(v, sv, a, z, sz, t, st, q, l, r, -20, 20, size)

with pm.Model() as model:
    q = pm.Uniform(name="q", lower=0, upper=0.1)
    l = pm.Uniform(name="l", lower=0, upper=1)
    r = pm.Uniform(name="r", lower=0, upper=1)

    lower = min(x)
    upper = max(x)
    
    a = pm.Gamma(name="a", mu=0.8, sigma=0.5)
    v = pm.Normal(name="v", mu=0, sigma=1)
    t = pm.Gamma(name="t", mu=0.1, sigma=0.1)
    
    WFPT(name="x", v=v, sv=sv, a=a, z=z, sz=sz, t=t, st=st, q=q, l=l, r=r, lower=lower, upper=upper, observed=x)
    
    results = pm.sample(5000, return_inferencedata=True)
    az.plot_trace(results)
    plt.show()

Here is the histogram of the data:
(screenshot from 2021-07-18: histogram of the generated data)


First, note that uniform, not exponential, is the typical contaminant distribution choice. I was simply trying out exponential to avoid having to set a max RT.

I think that l, the rate parameter of the exponential, is the culprit. If you look at the exponential distribution, you can see intuitively how l can trade off with q in these models. I would try fixing l to some realistic value, or switching out exponential for uniform.
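The trade-off is easy to see numerically: near x = 0 the weighted contaminant density is q·l·exp(−l·x) ≈ q·l, so doubling l while halving q leaves the early-RT contaminant mass nearly unchanged. A quick check with illustrative values:

```python
import numpy as np
from scipy import stats

x = np.linspace(0.0, 0.3, 50)  # early part of the RT range

# two (q, l) pairs with the same product q * l = 0.025
d1 = 0.05 * stats.expon.pdf(x, scale=1 / 0.5)   # q = 0.05,  l = 0.5
d2 = 0.025 * stats.expon.pdf(x, scale=1 / 1.0)  # q = 0.025, l = 1.0

# the weighted contaminant densities nearly coincide over this range,
# which is why q and l are hard to identify jointly
gap = np.max(np.abs(d1 - d2))
```

This near-ridge in the likelihood is exactly the kind of geometry that produces divergences in HMC/NUTS.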

Hi, fitting the DDM with Stan is quite popular, but the same problem occurs there: when a proposed non-decision time is not smaller than the shortest RT, the likelihood is zero, causing the bad initial energy. So is this a common problem for the DDM?

A general problem, with a known solution.

Hi, sorry for asking this stupid question. What's the general solution, apart from setting the initial value of the non-decision time to a relatively small value? Thanks.

@aetius You can do a mixture with a uniform outlier distribution. You can see how we do it in HDDM here: hddm/wfpt.pyx at master · hddm-devs/hddm · GitHub


Yes, using a uniform mixture.
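Besides the mixture, the small-initial-value idea mentioned above can be made systematic by reparameterizing: sample a unit-interval fraction and scale it by the minimum observed RT, so the non-decision time can never exceed the fastest response. A minimal numpy sketch of that constraint (names and values are illustrative, not from the thread's code):

```python
import numpy as np

rng = np.random.default_rng(0)
rt = rng.gamma(shape=2.0, scale=0.3, size=100) + 0.15  # fake RT data

# reparameterize: t = t_frac * min(rt), with t_frac in (0, 1)
# (t_frac would be e.g. a Uniform(0, 1) or Beta random variable in the model)
min_rt = rt.min()
t_frac = 0.5
t = t_frac * min_rt

# every decision time rt - t stays strictly positive by construction
assert np.all(rt - t > 0)
```

This keeps the sampler out of the zero-likelihood region entirely, rather than relying on a lucky initial value.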

Hi, I am trying to use your code to fit the "pure DDM" (without between-trial variability parameters) in PyMC3. When I fit a non-hierarchical model, the fitting procedure looks fine, but the hierarchical model cannot be properly initialized; there is always a 'bad energy' warning. What could cause this error? Thank you very much.

@sammosummo Thanks so much for your contribution. Just curious: your ddm_in_aesara.py imports a function from in_jax. Just to confirm, do you mean your ddm_in_jax.py here? I found no jax_wfpt_rvs function there; I only found jax_wfpt_pdf_sv. Can you point me to how to run your code?

Also, @jasongong11, it looks like you have successfully sampled from WFPT? Can you share your examples?

Thanks so much!

Honestly I have no clue, sorry.

The idea was to develop a submodule for the DDM and add it to PyMC via a PR, similar to glm. I wanted to do this because I have a bunch of experiments I could apply the DDM to, and wasn’t happy with the direction other Python packages (HDDM and PyDDM) were taking at the time.

I started to write the code almost two years ago at this point. I hit a couple of hurdles (documented above) then had to stop due to work and family commitments. I would love to finish it, but it looks unlikely that my schedule will ever get lighter again.

Dear DDM experts, @sammosummo @twiecki

I did some work on this. Note that PyMC has been updated to version 5 and now uses PyTensor, so some of the original pieces are no longer usable.

Based on the file above, I can correctly sample a single subject's data via pm.CustomDist() (see below).

Sample one subject

from ddm_in_pytensor import *

# generate data for a single subject
v = 1 # drift rate
sv = 0 
a = 0.8 # boundary
z = 0.5 # starting point
sz = 0.0
t = 0.01 # non decision time
st = 0.0
q = .02
l = 0.5
r = 0.8
size = 300

x = aesara_wfpt_rvs(v, sv, a, z, sz, t, st, q, l, r, size)

with pm.Model() as model:
    v = pm.HalfNormal(name="drift", sigma=5)
    a = pm.HalfNormal(name='bound', sigma=5)
    z = pm.Beta(name="startingPoint", alpha=4, beta=4)
    t = pm.HalfNormal(name="ndt", sigma=0.02)
    pm.CustomDist("data", v, sv, a, z, sz, t, st, q, l, r,
                  logp=aesara_wfpt_log_like,
                  random=aesara_wfpt_rvs,
                  observed=x)
    results = pm.sample(3000, initvals={'drift':np.array(1.), 'startingPoint':np.array(0.5), 'bound':np.array(0.8), 'ndt':0.01})

However, when I try to build a hierarchical model to sample multi-subject data, building the model reports a broadcasting error.

# generate data from two subjects
v1 = 1 # drift rate of subject 1
v2 = 2 # drift rate of subject 2
sv = 0
a = 0.8 # boundary
z = 0.5 # starting point
sz = 0.0
t = 0.0 # non decision time
st = 0.
q = .02
l = 0.5
r = 0.8
size = 300

x1 = aesara_wfpt_rvs(v1, sv, a, z, sz, t, st, q, l, r, size)
x2 = aesara_wfpt_rvs(v2, sv, a, z, sz, t, st, q, l, r, size)

import pandas as pd
data = pd.DataFrame({'RT': np.hstack((x1, x2)), 'subj': np.hstack((np.ones(size), np.ones(size)*2))})
subjIdx, uniqueSubj = pd.factorize(data.subj)
coords = {
    'subj': uniqueSubj,
    'obs_id': np.arange(data.subj.size)
}

with pm.Model(coords=coords) as hmodel: # initiate hierarchical model 
    
    subjIdx = pm.ConstantData('subjIdx', subjIdx, dims='obs_id')
    
    # initiate hyper distributions
    μ_v = pm.Uniform(name="μ_v", lower=0.1, upper=5) # drift rate
    σ_v = pm.HalfNormal(name="σ_v", sigma=5)
#    μ_a = pm.Uniform(name="μ_a", lower=0.01, upper=5) # bound
#    σ_a = pm.HalfNormal(name="σ_a", sigma=5)
    
#    μ_beta = pm.Uniform(name="μ_beta", lower=2, upper=10) # starting point
    
#    σ_ndt = pm.HalfNormal(name="σ_ndt", sigma=5) # ndt
    
    # initiate hyper distribution
    v = pm.Gamma(name='v', mu=μ_v, sigma=σ_v, dims="subj", initval=np.array([1, 1]))

#    a = pm.Gamma(name='a', mu=0.8, sigma=1, initval=0.8)
#    z = pm.Beta(name="z", alpha=4, beta=4, initval=0.5)
#    t = pm.HalfNormal(name="ndt", sigma=0.01, initval=0.01)
    
#    eps = pm.HalfNormal('eps', sigma=10)
    
#    y = pm.Normal('data', mu=v[subjIdx], sigma=eps, observed=data.RT, dims='obs_id') # test likelihood fun
    
    y = pm.CustomDist("data", v[subjIdx], sv, a, z, sz, t, st, q, l, r,
                      logp=aesara_wfpt_log_like,
                      random=aesara_wfpt_rvs,
                      observed=data.RT,
                      dims='obs_id')

It is hard for me to understand this error or to fix it:

…/miniconda3/envs/pymc_env/lib/python3.11/site-packages/pytensor/tensor/rewriting/shape.py:169: UserWarning: Failed to infer_shape from Op Elemwise{mul,no_inplace}.
Input shapes: [(TensorConstant{48000},), (TensorConstant{600},)]
Exception encountered during infer_shape: <class 'ValueError'>
Exception message: Could not broadcast dimensions
Traceback: Traceback (most recent call last):
  File "…/miniconda3/envs/pymc_env/lib/python3.11/site-packages/pytensor/tensor/rewriting/shape.py", line 145, in get_node_infer_shape
    o_shapes = shape_infer(
  File "…/miniconda3/envs/pymc_env/lib/python3.11/site-packages/pytensor/tensor/elemwise.py", line 835, in infer_shape
    out_shape = pytensor.tensor.broadcast_shape(*i_shapes, arrays_are_shapes=True)
  File "…/miniconda3/envs/pymc_env/lib/python3.11/site-packages/pytensor/tensor/extra_ops.py", line 1455, in broadcast_shape
    return broadcast_shape_iter(arrays, **kwargs)
  File "…/miniconda3/envs/pymc_env/lib/python3.11/site-packages/pytensor/tensor/extra_ops.py", line 1536, in broadcast_shape_iter
    raise ValueError("Could not broadcast dimensions")
ValueError: Could not broadcast dimensions

warn(msg)
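The shapes in the warning suggest the likelihood internally flattens an array (48000 = 600 observations × 80 series terms, presumably) in a way that no longer broadcasts once `v[subjIdx]` becomes a length-600 vector. A numpy sketch of the shape discipline the logp would need: keep the observation axis and the series-expansion axis separate and sum over the latter, so per-observation parameters broadcast cleanly. The series below is a toy stand-in, not the actual WFPT expansion:

```python
import numpy as np

n_obs, n_terms = 600, 80
rng = np.random.default_rng(1)
x = np.abs(rng.normal(0.5, 0.2, size=n_obs))         # fake RTs, shape (600,)
v = np.where(np.arange(n_obs) < 300, 1.0, 2.0)        # per-observation drift

# series terms on a trailing axis: shape (n_obs, n_terms), NOT a flat
# (n_obs * n_terms,) vector, so the per-observation v broadcasts against it
k = np.arange(1, n_terms + 1)                         # (n_terms,)
terms = np.exp(-k**2 / 2.0) * np.sin(k * x[:, None])  # toy series, (600, 80)
series_sum = terms.sum(axis=-1)                       # (600,)
logp = v * series_sum                                 # broadcasts to (600,)
```

If the actual `aesara_wfpt_log_like` concatenates the expansion terms into one flat vector, multiplying it by a length-600 parameter vector would produce exactly the (48000,) vs (600,) mismatch in the traceback.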
