I’m trying to convert a Bayesian JAGS model to PyMC, but I’m having problems with the sampling. Here is the model:
import numpy as np
import pymc as pm
import pytensor.tensor as pt
from pytensor import scan

def model_1v_LG2D(data):
    with pm.Model() as model:
        # Convert data to tensor variables
        y = pt.as_tensor_variable(data['y'], ndim = 2)
        d_phy = pt.as_tensor_variable(data['d_phy'], ndim = 2)
        r = pt.as_tensor_variable(data['r'], ndim = 2)
        # Hyperpriors
        lambda_mu = pm.TruncatedNormal('lambda_mu', mu = 0.1, sigma = 1, lower = 0, initval = 0.1)
        lambda_sigma = pm.Uniform('lambda_sigma', lower = 1e-9, upper = 1, initval = 0.001)
        alpha_mu = pm.Truncated('alpha_mu', dist = pm.Beta.dist(alpha = 1, beta = 1), lower = 1e-9, upper = 1 - 1e-9, initval = 0.5)
        alpha_kappa = pm.Uniform('alpha_kappa', lower = 1, upper = 10, initval = 5)
        sigma = pm.Uniform('sigma', lower = [1e-9, 1.5], upper = [1.5, 3], shape = 2, initval = [0.001, 2])
        pi = pm.Dirichlet('pi', a = np.array([1, 1, 1, 1]), initval = [0.25, 0.25, 0.25, 0.25])
        w0_mu = pm.Normal('w0_mu', mu = 0, sigma = 0.1, initval = 0)
        w0_sigma = pm.Gamma('w0_sigma', alpha = 2, beta = 1, initval = 2)
        w1_a = pm.Truncated('w1_a', dist = pm.Gamma.dist(alpha = 2, beta = 1), lower = 1e-9, initval = 1)
        w1_b = pm.Truncated('w1_b', dist = pm.Gamma.dist(alpha = 2, beta = 1), lower = 1e-9, initval = 1)
        # Latent group indicators (Categorical is 0-based, so gp shifts to groups 1-4)
        gp_1 = pm.Categorical('gp_1', p = pi, shape = data['Nparticipants'])
        gp = pm.Deterministic('gp', gp_1 + 1)
        # Priors
        lambda_1 = pm.TruncatedNormal('lambda_1', mu = lambda_mu, sigma = lambda_sigma, lower = 0.0052, shape = data['Nparticipants'], initval = np.repeat(0.1, data['Nparticipants']))
        lambda_2 = pm.TruncatedNormal('lambda_2', mu = lambda_mu, sigma = lambda_sigma, lower = 0, upper = 0.0052, shape = data['Nparticipants'], initval = np.repeat(0.0026, data['Nparticipants']))
        lambda_ = pm.Deterministic('lambda', pm.math.switch(pm.math.eq(gp, 1), 0, pm.math.switch(pm.math.eq(gp, 2), lambda_2, lambda_1)))
        w0 = pm.Normal('w0', mu = w0_mu, sigma = w0_sigma, shape = data['Nparticipants'])
        w1_1 = pm.Truncated('w1_1', dist = pm.Gamma.dist(alpha = w1_a, beta = w1_b), lower = 1, shape = data['Nparticipants'], initval = np.repeat(1.1, data['Nparticipants']))
        w1 = pm.Deterministic('w1', pm.math.switch(pm.math.eq(gp, 1), 0, w1_1))
        alpha_1 = pm.Truncated('alpha_1', dist = pm.Beta.dist(alpha = alpha_mu * alpha_kappa, beta = (1 - alpha_mu) * alpha_kappa), lower = 1e-9, upper = 1 - 1e-9, shape = data['Nparticipants'], initval = np.repeat(0.5, data['Nparticipants']))
        alpha = pm.math.switch(pm.math.eq(gp, 1), 0, alpha_1)
        d_sigma = pm.Truncated('d_sigma', dist = pm.Gamma.dist(alpha = 2, beta = 1), lower = 1e-9, shape = data['Nparticipants'], initval = np.repeat(2, data['Nparticipants']))
        zn = pm.Deterministic('zn', pm.math.switch(pm.math.eq(gp, 1), 1, 0))
        # Generate the missing perceptual distance between CS+ and S
        d_per = pm.Normal('d_per', mu = d_phy, sigma = d_sigma[:, None], shape = (data['Nparticipants'], data['Ntrials']))

        def trial_update(d_phy_ij, r_ij, d_per_ij, v, theta_prev, g_prev, s_prev, gp_i, lambda_i, w0_i, w1_i, alpha_i):
            # Generalization
            s = pt.switch((v > 0) & (gp_i > 1), pt.exp(-lambda_i * pt.switch(pt.eq(gp_i, 4), d_per_ij, d_phy_ij)), 1)
            g = v * s
            # Non-linear transformation (latent to observed scale)
            theta = 1 + (10 - 1) / (1 + pt.exp(-(w0_i + w1_i * g)))
            # Learning: excitatory associative strength
            v_next = pt.switch(pt.neq(gp_i, 1), pt.switch(pt.eq(r_ij, 1), v + alpha_i * (r_ij - v), v), 0)
            return [v_next, theta, g, s]

        def participant_update(d_phy_i, r_i, d_per_i, gp_i, lambda_i, w0_i, w1_i, alpha_i):
            v_init = pt.as_tensor_variable(0.0, dtype='float64')  # The initial excitatory associative strength
            theta_init = pt.as_tensor_variable(0.0, dtype='float64')
            g_init = pt.as_tensor_variable(0.0, dtype='float64')
            s_init = pt.as_tensor_variable(0.0, dtype='float64')
            # Scan over trials
            [v_sequence, theta_sequence, g_sequence, s_sequence], _ = scan(
                fn = trial_update,
                sequences = [d_phy_i, r_i, d_per_i],
                outputs_info = [v_init, theta_init, g_init, s_init],
                non_sequences = [gp_i, lambda_i, w0_i, w1_i, alpha_i]
            )
            return [v_sequence, theta_sequence, g_sequence, s_sequence]

        # Scan over participants
        [v_sequence, theta_sequence, g_sequence, s_sequence], _ = scan(
            fn = participant_update,
            sequences = [d_phy, r, d_per, gp, lambda_, w0, w1, alpha]
        )
        # Convert the output sequences to PyMC variables (can be removed)
        v = pm.Deterministic('v', v_sequence)
        theta = pm.Deterministic('theta', theta_sequence)
        g = pm.Deterministic('g', g_sequence)
        s = pm.Deterministic('s', s_sequence)
        # Likelihood
        y_likelihood = pm.Normal('y_likelihood', mu = theta, sigma = pt.repeat(sigma[zn][:, None], repeats = data['y'].shape[1], axis = 1), observed = y, shape = (data['Nparticipants'], data['Ntrials']))
    return model
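For context, the builder only needs a data dict with the keys used above. Here is a minimal sketch with dummy arrays matching the (40, 188) shape mentioned below; the value ranges are invented purely for illustration and the real experimental data of course differ:

import numpy as np

rng = np.random.default_rng(0)
Nparticipants, Ntrials = 40, 188
data = {
    'Nparticipants': Nparticipants,
    'Ntrials': Ntrials,
    'y': rng.uniform(1, 10, size = (Nparticipants, Ntrials)),      # observed ratings on the 1-10 scale used by theta
    'd_phy': rng.uniform(0, 5, size = (Nparticipants, Ntrials)),   # physical CS+/S distances (range made up)
    'r': rng.integers(0, 2, size = (Nparticipants, Ntrials)).astype('float64'),  # 0/1 reinforcement indicator
}
model = model_1v_LG2D(data)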
Since I added the likelihood (y_likelihood), sampling 4 chains of 2000 draws each is estimated to take over 45 hours. Here is the code I am using to sample:
# Trace
trace = pm.sample(draws = 2000, tune = 2000, chains = 4, cores = 4,
                  return_inferencedata = True,
                  compute_convergence_checks = True,
                  discard_tuned_samples = True)
# Prediction
post_pred = pm.sample_posterior_predictive(trace)
Without the likelihood, sampling takes about 18 minutes. I can't use a sampler other than the default one, such as the BlackJAX or NumPyro NUTS samplers, because the model contains a discrete variable (gp_1) on which many of the continuous variables depend. The observed data and the likelihood parameters are 2D arrays of shape (40, 188), so I understand that sampling may be slow. However, after 2 days it is still at 32%. Is it normal for it to be this slow, or is there something fundamentally wrong with the model?
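In case it helps with diagnosing this, here is how I could profile repeated logp / gradient evaluations to see whether the two nested scans dominate the cost. This is just a sketch using PyMC's Model.profile (the gradient is restricted to the continuous free parameters because gp_1 is discrete), and I have not included any output here:

# Profile repeated evaluations of the joint log-probability and of its gradient
# with respect to the continuous free parameters (gp_1 is discrete, so it is excluded).
model = model_1v_LG2D(data)
model.profile(model.logp()).summary()
continuous_rvs = [rv for rv in model.free_RVs if rv.dtype.startswith('float')]
model.profile(model.dlogp(vars = continuous_rvs)).summary()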
I’m using Windows and running the model in a conda environment created on WSL. My PyMC version is 5.15.1 and PyTensor is 2.22.1.
Any help or advice would be much appreciated.