I’m just getting started teaching myself some PyMC3 and am looking to re-implement a JAGS Multinomial Processing Tree (MPT) model from the “Computational Modeling of Cognition and Behavior” textbook by Farrell & Lewandowsky (2018). The model is based on Wagenaar & Boer’s (1987) ‘No Conflict’ MPT model and developed further in Vandekerckhove et al. (2015). I’m happy to provide more specific model details, but I think the issue might be more around how I’ve coded it up.
My code so far more or less follows the JAGS code line-for-line (as linked above), and looks like this:
```python
import pymc3 as pm
import numpy as np

# Import Data: observed response counts per condition
obs_consistent = np.array([78, 70, 7, 15])
obs_inconsistent = np.array([102, 55, 40, 53])
obs_neutral = np.array([63, 45, 13, 21])

with pm.Model() as MPT_Noconflict_model:
    # Priors for MPT Parameters
    p = pm.Beta('p', alpha=1, beta=1)
    q = pm.Beta('q', alpha=1, beta=1)
    c = pm.Beta('c', alpha=1, beta=1)

    # Predicted probabilities for:
    ## Consistent Condition
    c_pp = (1 + p + q - p*q + 4*p*c) / 6
    c_pm = (1 + p + q - p*q - 2*p*c) / 3
    c_mp = (1 - p - q + p*q) / 6
    c_mm = (1 - p - q + p*q) / 3

    ## Inconsistent Condition
    i_pp = (1 + p - q + p*q + 4*p*c) / 6
    i_pm = (1 + p - q + p*q - 2*p*c) / 3
    i_mp = (1 - p + q - p*q) / 6
    i_mm = (1 - p + q - p*q) / 3

    ## Neutral Condition
    n_pp = (1 + p + 4*p*c) / 6
    n_pm = (1 + p - 2*p*c) / 3
    n_mp = (1 - p) / 6
    n_mm = (1 - p) / 3

    # Observed counts from Multinomial as fn of predicted probs
    consistent = pm.Multinomial('Consistent', p=[c_pp, c_pm, c_mp, c_mm],
                                n=170, observed=obs_consistent)
    inconsistent = pm.Multinomial('Inconsistent', p=[i_pp, i_pm, i_mp, i_mm],
                                  n=250, observed=obs_inconsistent)
    neutral = pm.Multinomial('Neutral', p=[n_pp, n_pm, n_mp, n_mm],
                             n=142, observed=obs_neutral)

    # Sample using NUTS (the default)
    # step = pm.Metropolis(vars=[p, q, c])  # uncomment and pass step=step for Metropolis
    MPT_Noconflict_trace = pm.sample(5000, init='adapt_diag')
```
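As a sanity check on the algebra (a plain-NumPy sketch, separate from the PyMC3 model), the predicted branch probabilities for each condition do sum to 1 for any p, q, c in (0, 1), so I don’t think I’ve mis-transcribed a branch:

```python
import numpy as np

def noconflict_probs(p, q, c):
    """Predicted category probabilities for the 'No Conflict' MPT model,
    one array per condition (consistent, inconsistent, neutral)."""
    consistent = np.array([
        (1 + p + q - p*q + 4*p*c) / 6,
        (1 + p + q - p*q - 2*p*c) / 3,
        (1 - p - q + p*q) / 6,
        (1 - p - q + p*q) / 3,
    ])
    inconsistent = np.array([
        (1 + p - q + p*q + 4*p*c) / 6,
        (1 + p - q + p*q - 2*p*c) / 3,
        (1 - p + q - p*q) / 6,
        (1 - p + q - p*q) / 3,
    ])
    neutral = np.array([
        (1 + p + 4*p*c) / 6,
        (1 + p - 2*p*c) / 3,
        (1 - p) / 6,
        (1 - p) / 3,
    ])
    return consistent, inconsistent, neutral

# Each condition's probabilities should form a proper simplex
rng = np.random.default_rng(0)
for _ in range(1000):
    p, q, c = rng.uniform(size=3)
    for probs in noconflict_probs(p, q, c):
        assert np.isclose(probs.sum(), 1.0)
        assert (probs >= 0).all()
```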
Using the Metropolis algorithm, this model samples ok in under 3 minutes or so, without divergences, and reproduces the results in the book nicely. Using PyMC3’s default NUTS algorithm, however, leads to incredibly slow sampling, rejected chains, and thousands of divergences.
How can I go about improving this model’s NUTS sampling?
My first suspicion is that assigning all of my probability parameters separately before feeding them into the multinomial likelihood terms may be inefficient, but I’m not sure how I could do this better. I’ve written this model in a similar manner to the existing BCM Notes for MPTs, which seemed to work ok (albeit with fewer branches), and I can’t find a directly applicable solution in the FAQs.
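For example, one vectorized form I’ve been considering (sketched here in plain NumPy; in the model itself the dot products would presumably go through `tt.dot` / `pm.math.dot` on the Theano tensors) expresses each condition’s four branch probabilities as a single matrix product over the basis `[1, p, q, p*q, p*c]`, rather than twelve separate scalar expressions:

```python
import numpy as np

# Coefficient matrices over the basis [1, p, q, p*q, p*c]; each row is one
# response category, with a common denominator of 6 factored out.
M_consistent = np.array([
    [1,  1,  1, -1,  4],
    [2,  2,  2, -2, -4],
    [1, -1, -1,  1,  0],
    [2, -2, -2,  2,  0],
]) / 6.0
M_inconsistent = np.array([
    [1,  1, -1,  1,  4],
    [2,  2, -2,  2, -4],
    [1, -1,  1, -1,  0],
    [2, -2,  2, -2,  0],
]) / 6.0
M_neutral = np.array([
    [1,  1,  0,  0,  4],
    [2,  2,  0,  0, -4],
    [1, -1,  0,  0,  0],
    [2, -2,  0,  0,  0],
]) / 6.0

def condition_probs(p, q, c):
    """All three conditions' probability vectors from one basis vector."""
    basis = np.array([1.0, p, q, p*q, p*c])
    return M_consistent @ basis, M_inconsistent @ basis, M_neutral @ basis

# Spot-check one entry against the original scalar expression
p, q, c = 0.3, 0.6, 0.8
cons, inc, neu = condition_probs(p, q, c)
assert np.isclose(cons[0], (1 + p + q - p*q + 4*p*c) / 6)
```

If this is the right direction, each `pm.Multinomial` could then take a single probability tensor instead of a Python list of four separate scalar expressions, though I don’t know whether that’s actually what is hurting NUTS here.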
Any help or guidance would be appreciated!