Efficient Sampling of a Multinomial Processing Tree Model using NUTS?

Hi all!

I’m just starting to teach myself PyMC3 and am looking to re-implement a JAGS Multinomial Processing Tree (MPT) model from the “Computational Modeling of Cognition and Behavior” textbook by Farrell & Lewandowsky (2018). The model is based on Wagenaar & Boer’s (1987) ‘No Conflict’ MPT model and was developed further in Vandekerckhove et al. (2015). I’m happy to provide more specific model details, but I think the issue might be more about how I’ve coded it up.

My code so far more or less follows the JAGS code line-for-line (as linked above), and looks like this:

import pymc3 as pm
import numpy as np

# Import Data
obs_consistent = np.array([78, 70, 7, 15])
obs_inconsistent = np.array([102, 55, 40, 53])
obs_neutral = np.array([63, 45, 13, 21])

with pm.Model() as MPT_Noconflict_model:
    # Priors for MPT parameters
    p = pm.Beta('p', alpha=1, beta=1)
    q = pm.Beta('q', alpha=1, beta=1)
    c = pm.Beta('c', alpha=1, beta=1)
    
    # Predicted probabilities for:
    ## Consistent Condition
    c_pp = (1 + p + q - p*q + 4 * p*c)/6
    c_pm = (1 + p + q - p*q - 2 * p*c)/3
    c_mp = (1 - p - q + p*q)/6
    c_mm = (1 - p - q + p*q)/3
    
    ## Inconsistent Condition
    i_pp = (1 + p - q + p*q + 4 * p*c)/6
    i_pm = (1 + p - q + p*q - 2 * p*c)/3
    i_mp = (1 - p + q - p*q)/6
    i_mm = (1 - p + q - p*q)/3
    
    ## Neutral Condition
    n_pp = (1 + p + 4 * p*c)/6
    n_pm = (1 + p - 2 * p*c)/3
    n_mp = (1 - p)/6
    n_mm = (1 - p)/3
    
    # Observed counts: Multinomial likelihood as a function of the predicted probabilities
    consistent = pm.Multinomial('Consistent', p=[c_pp, c_pm, c_mp, c_mm], n=170, observed=obs_consistent)
    inconsistent = pm.Multinomial('Inconsistent', p=[i_pp, i_pm, i_mp, i_mm], n=250, observed=obs_inconsistent)
    neutral = pm.Multinomial('Neutral', p=[n_pp, n_pm, n_mp, n_mm], n=142, observed=obs_neutral)
    
    # Sample using NUTS (PyMC3's default for continuous parameters)
    # For Metropolis instead, uncomment the next line and call pm.sample(5000, step=step)
    # step = pm.Metropolis(vars=[p, q, c])
    MPT_Noconflict_trace = pm.sample(5000, init='adapt_diag')

Using the Metropolis algorithm, this model samples fine in under 3 minutes or so, without divergences, and reproduces the results in the book nicely. Using PyMC3’s default NUTS algorithm, however, leads to incredibly slow sampling, rejected chains, and thousands of divergences.

How can I go about improving this model’s NUTS sampling?

My first suspicion is that assigning all my probability parameters separately to feed the multinomial likelihood terms may be inefficient, but I’m not sure how I could do this better. I’ve written this in a similar manner to the existing BCM notes for MPTs, which seemed to work fine (albeit with fewer branches), and I can’t find a directly applicable solution in the FAQs.
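
One way to avoid a dozen separate probability variables is to build a single 3×4 probability matrix (rows: consistent, inconsistent, neutral; columns: ++, +-, -+, --) and evaluate one multinomial likelihood over all rows. The NumPy sketch below only illustrates that layout and checks that each row sums to 1; the helper names `mpt_prob_matrix` and `multinomial_loglik` are hypothetical. In PyMC3 the same matrix could presumably be built with `pm.math.stack` and passed to a single `pm.Multinomial` with `n=[170, 250, 142]` and a 3×4 `observed` array, but treat that as an untested assumption rather than a confirmed fix.

```python
import numpy as np

# Observed counts; rows = conditions, columns = response categories (++, +-, -+, --)
OBS = np.array([
    [78, 70, 7, 15],     # consistent
    [102, 55, 40, 53],   # inconsistent
    [63, 45, 13, 21],    # neutral
])
N = OBS.sum(axis=1)      # trials per condition


def mpt_prob_matrix(p, q, c):
    """Return the 3x4 matrix of predicted category probabilities (hypothetical helper)."""
    # Shared "+" base terms, taken from the per-condition formulas in the model
    cons_plus = 1 + p + q - p * q
    incon_plus = 1 + p - q + p * q
    neut_plus = 1 + p
    return np.array([
        [(cons_plus + 4 * p * c) / 6, (cons_plus - 2 * p * c) / 3,
         (1 - p - q + p * q) / 6, (1 - p - q + p * q) / 3],
        [(incon_plus + 4 * p * c) / 6, (incon_plus - 2 * p * c) / 3,
         (1 - p + q - p * q) / 6, (1 - p + q - p * q) / 3],
        [(neut_plus + 4 * p * c) / 6, (neut_plus - 2 * p * c) / 3,
         (1 - p) / 6, (1 - p) / 3],
    ])


def multinomial_loglik(p, q, c):
    """Multinomial log-likelihood of all three conditions, dropping the constant term."""
    probs = mpt_prob_matrix(p, q, c)
    return float(np.sum(OBS * np.log(probs)))


if __name__ == "__main__":
    print(mpt_prob_matrix(0.3, 0.6, 0.8).sum(axis=1))  # each row should sum to 1
    print(N)                                            # [170 250 142]
    print(multinomial_loglik(0.5, 0.5, 0.5))
```

Checking that every row sums to 1 for arbitrary (p, q, c) is a cheap way to catch typos in the tree's branch formulas before blaming the sampler.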

Any help or guidance would be appreciated!

I think the reason the model works poorly with NUTS is that the prior is too vague, and since the likelihood carries little information, NUTS/HMC is likely to struggle in the tails. MH might indeed work better in such a low-dimensional problem, but it does not have the diagnostics to tell you when it doesn’t work.
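
Because the model has only three parameters, a brute-force grid approximation of the posterior is one possible way to cross-check either sampler and to see how much a Beta(1, 1) versus Beta(2, 2) prior actually matters. This is a sketch under stated assumptions, not something proposed in the thread: it evaluates the unnormalized log-posterior on a regular grid over (0, 1)^3 and normalizes numerically. The helper name `grid_posterior_means` and the grid size `G` are hypothetical choices.

```python
import numpy as np

# Observed counts; rows = consistent, inconsistent, neutral; columns = ++, +-, -+, --
OBS = np.array([[78, 70, 7, 15], [102, 55, 40, 53], [63, 45, 13, 21]])


def grid_posterior_means(a=1.0, b=1.0, G=80):
    """Posterior means of (p, q, c) under independent Beta(a, b) priors (hypothetical helper)."""
    x = (np.arange(G) + 0.5) / G   # midpoints of G equal cells in (0, 1)
    p = x[:, None, None]           # p varies along axis 0
    q = x[None, :, None]           # q varies along axis 1
    c = x[None, None, :]           # c varies along axis 2

    # Predicted probabilities per condition, same formulas as the PyMC3 model
    rows = [
        [(1 + p + q - p*q + 4*p*c) / 6, (1 + p + q - p*q - 2*p*c) / 3,
         (1 - p - q + p*q) / 6, (1 - p - q + p*q) / 3],
        [(1 + p - q + p*q + 4*p*c) / 6, (1 + p - q + p*q - 2*p*c) / 3,
         (1 - p + q - p*q) / 6, (1 - p + q - p*q) / 3],
        [(1 + p + 4*p*c) / 6, (1 + p - 2*p*c) / 3, (1 - p) / 6, (1 - p) / 3],
    ]

    # Multinomial log-likelihood (constant dropped), accumulated term by term
    log_post = np.zeros((G, G, G))
    for counts, probs in zip(OBS, rows):
        for k, pr in zip(counts, probs):
            log_post = log_post + k * np.log(pr)

    # Independent Beta(a, b) log-priors, up to a constant
    for theta in (p, q, c):
        log_post = log_post + (a - 1) * np.log(theta) + (b - 1) * np.log(1 - theta)

    # Normalize on the grid; subtracting the max avoids underflow in exp
    w = np.exp(log_post - log_post.max())
    w /= w.sum()
    return tuple(float(np.sum(w * theta)) for theta in (p, q, c))


if __name__ == "__main__":
    print(grid_posterior_means(1, 1))  # flat priors, as in the book
    print(grid_posterior_means(2, 2))  # the Beta(2, 2) priors tried later in the thread
```

If the grid means agree with the MCMC output, the sampler is probably fine; if NUTS disagrees or reports many divergences while the grid is well behaved, the problem is more likely the software setup than the model.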

Actually, calling MPT_Noconflict_trace = pm.sample(5000) directly seems to work fine; are you using the newest release?

Good point re: the priors; they’re effectively uniform in the text. I just tried re-running it with Beta(2, 2) priors, but I’m encountering extremely slow sampling once again.

I’m running this on a fresh install of Miniconda from last night, in a virtual environment using PyMC3 v3.9.2 and Theano v1.0.4.

I can’t seem to reproduce it :-/

This model ran for me in a few seconds with 5000 draws. I’m on PyMC3 from the master branch.
The posteriors are all centered on 0.5, though; is that expected?
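
As a rough plausibility check (a back-of-the-envelope moment estimate, not the posterior itself), the "-" responses alone pin down p and q, because the two "-" categories in each condition sum to a simple expression. Using the model's formulas and the observed counts:

$$
\begin{aligned}
n_{mp} + n_{mm} &= \tfrac{1-p}{6} + \tfrac{1-p}{3} = \tfrac{1-p}{2}, &\quad \hat p &= 1 - 2\cdot\tfrac{13+21}{142} \approx 0.52 \\
c_{mp} + c_{mm} &= \tfrac{(1-p)(1-q)}{2}, &\quad (1-\hat p)(1-\hat q) &= 2\cdot\tfrac{7+15}{170} \approx 0.259 \\
i_{mp} + i_{mm} &= \tfrac{(1-p)(1+q)}{2}, &\quad (1-\hat p)(1+\hat q) &= 2\cdot\tfrac{40+53}{250} = 0.744
\end{aligned}
$$

Dividing the last two lines gives (1 + q)/(1 - q) ≈ 2.87, i.e. q ≈ 0.48, and then p ≈ 1 - 0.744/1.48 ≈ 0.50. So values of p and q near 0.5 are what the data suggest; c is identified from the split between the two "+" categories.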

On investigation, the issue turned out to be installing Theano and getting it to ‘play nice’ with PyMC3 on Windows 10, rather than a problem with the model itself.

Depending on whether I installed everything via conda, via conda-forge (e.g. theano, m2w64-toolchain, mkl-service, etc.), or manually installed MinGW for the g++ compiler (as suggested here), I variously encountered errors very similar to those described here and here. Typically the model would stall at around 3.40%, regardless of the number of cores or chains, the priors, or the inits.

My current solution is not ideal but works well, sampling within 15 seconds and yielding results similar to @nkaimcaudle’s (thanks, by the way; that looks correct as per the text!). I followed these instructions, setting up a virtual environment using somewhat older versions of Python and PyMC3 (both 3.5). EDIT: Actually, even this only works with one core. Trying to add more throws this error 🙃.

I’ll keep trying to fix the issue in a separate virtual env (I just came across another post that may be helpful) and will report back if I find any solutions, but setting things up on Windows 10 unfortunately remains a bit of a headache for now. Thank you, thank you @junpenglao, and I’m super excited for PyMC4!

@aseyboldt is working on a solution to make installation and environment setup more stable on Windows for multi-core sampling; hopefully it will help a bit. Stay tuned 😅