Very slow sampling of hierarchical, non-linear Bayesian model after adding likelihood function

I’m trying to convert a Bayesian JAGS model to PyMC, but I’m running into problems with sampling speed. Here is the model:

import numpy as np
import pymc as pm
import pytensor.tensor as pt
from pytensor import scan

def model_1v_LG2D(data):
    with pm.Model() as model:
        
        # Convert data to tensor variables
        y = pt.as_tensor_variable(data['y'], ndim = 2)
        d_phy = pt.as_tensor_variable(data['d_phy'], ndim = 2)
        r = pt.as_tensor_variable(data['r'], ndim = 2)
        
        # Hyperpriors
        lambda_mu = pm.TruncatedNormal('lambda_mu', mu = 0.1, sigma = 1, lower = 0, initval = 0.1)
        lambda_sigma = pm.Uniform('lambda_sigma', lower = 1e-9, upper = 1, initval = 0.001)
        alpha_mu = pm.Truncated('alpha_mu', dist = pm.Beta.dist(alpha = 1, beta = 1), lower = 1e-9, upper = 1 - 1e-9, initval = 0.5)
        alpha_kappa = pm.Uniform('alpha_kappa', lower = 1, upper = 10, initval = 5)
        sigma = pm.Uniform('sigma', lower = [1e-9, 1.5], upper = [1.5, 3], shape = 2, initval = [0.001, 2])
        pi = pm.Dirichlet('pi', a = np.array([1, 1, 1, 1]), initval = [0.25, 0.25, 0.25, 0.25])
        w0_mu = pm.Normal('w0_mu', mu = 0, sigma = 0.1, initval = 0)
        w0_sigma = pm.Gamma('w0_sigma', alpha = 2, beta = 1, initval = 2)
        w1_a = pm.Truncated('w1_a', dist = pm.Gamma.dist(alpha = 2, beta = 1), lower = 1e-9, initval = 1)
        w1_b = pm.Truncated('w1_b', dist = pm.Gamma.dist(alpha = 2, beta = 1), lower = 1e-9, initval = 1)

        # Latent group indicators
        gp_1 = pm.Categorical('gp_1', p = pi, shape = data['Nparticipants'])
        gp = pm.Deterministic('gp', gp_1 + 1)
        
        # Priors
        lambda_1 = pm.TruncatedNormal('lambda_1', mu = lambda_mu, sigma = lambda_sigma, lower = 0.0052, shape = data['Nparticipants'], initval = np.repeat(0.1, data['Nparticipants']))
        lambda_2 = pm.TruncatedNormal('lambda_2', mu = lambda_mu, sigma = lambda_sigma, lower = 0, upper = 0.0052, shape = data['Nparticipants'], initval = np.repeat(0.0026, data['Nparticipants']))
        lambda_ = pm.Deterministic('lambda', pm.math.switch(pm.math.eq(gp, 1), 0, pm.math.switch(pm.math.eq(gp, 2), lambda_2, lambda_1)))
        w0 = pm.Normal('w0', mu = w0_mu, sigma = w0_sigma, shape = data['Nparticipants'])
        w1_1 = pm.Truncated('w1_1', dist = pm.Gamma.dist(alpha = w1_a, beta = w1_b), lower = 1, shape = data['Nparticipants'], initval = np.repeat(1.1, data['Nparticipants']))
        w1 = pm.Deterministic('w1', pm.math.switch(pm.math.eq(gp, 1), 0, w1_1))
        alpha_1 = pm.Truncated('alpha_1', dist = pm.Beta.dist(alpha = alpha_mu * alpha_kappa, beta = (1 - alpha_mu) * alpha_kappa), lower = 1e-9, upper = 1 - 1e-9, shape = data['Nparticipants'], initval = np.repeat(0.5, data['Nparticipants']))
        alpha = pm.math.switch(pm.math.eq(gp, 1), 0, alpha_1)
        d_sigma = pm.Truncated('d_sigma', dist = pm.Gamma.dist(alpha = 2, beta = 1), lower = 1e-9, shape = data['Nparticipants'], initval = np.repeat(2, data['Nparticipants']))
        zn = pm.Deterministic('zn', pm.math.switch(pm.math.eq(gp, 1), 1, 0))
        
        # Generate missing perceptual distance between CS+ and S
        d_per = pm.Normal('d_per', mu = d_phy, sigma = d_sigma[:, None], shape = (data['Nparticipants'], data['Ntrials']))
        
        def trial_update(d_phy_ij, r_ij, d_per_ij, v, theta_prev, g_prev, s_prev, gp_i, lambda_i, w0_i, w1_i, alpha_i):
            
            # Generalization   
            s = pt.switch((v > 0) & (gp_i > 1), pt.exp(-lambda_i * pt.switch(pt.eq(gp_i, 4), d_per_ij, d_phy_ij)), 1)
            g = v * s
            
            # Non-linear transformation (latent to observed scale)
            theta = 1 + (10 - 1) / (1 + pt.exp(-(w0_i + w1_i * g)))
            
            # Learning
            # Excitatory associative strength
            v_next = pt.switch(pt.neq(gp_i, 1), pt.switch(pt.eq(r_ij, 1), v + alpha_i * (r_ij - v), v), 0)

            return [v_next, theta, g, s]
        
        def participant_update(d_phy_i, r_i, d_per_i, gp_i, lambda_i, w0_i, w1_i, alpha_i):
            
            v_init = pt.as_tensor_variable(0.0, dtype='float64') # The initial excitatory associative strength
            theta_init = pt.as_tensor_variable(0.0, dtype='float64') 
            g_init = pt.as_tensor_variable(0.0, dtype='float64') 
            s_init = pt.as_tensor_variable(0.0, dtype='float64')
            
            # Scan over trials
            [v_sequence, theta_sequence, g_sequence, s_sequence], _ = scan( 
                fn = trial_update,
                sequences = [d_phy_i, r_i, d_per_i],
                outputs_info = [v_init, theta_init, g_init, s_init],
                non_sequences = [gp_i, lambda_i, w0_i, w1_i, alpha_i]
            )
            
            return [v_sequence, theta_sequence, g_sequence, s_sequence]
        
        # Scan over participants
        [v_sequence, theta_sequence, g_sequence, s_sequence], _ = scan(
            fn = participant_update,
            sequences = [d_phy, r, d_per, gp, lambda_, w0, w1, alpha]
        )
        
        # Convert the output sequences to pymc variables (can be removed) 
        v = pm.Deterministic('v', v_sequence)
        theta = pm.Deterministic('theta', theta_sequence)
        g = pm.Deterministic('g', g_sequence)
        s = pm.Deterministic('s', s_sequence)
        
        # Likelihood
        y_likelihood = pm.Normal('y_likelihood', mu = theta, sigma = pt.repeat(sigma[zn][:, None], repeats = data['y'].shape[1], axis = 1), observed = y, shape = (data['Nparticipants'], data['Ntrials']))

        return model
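
For reference, this is the kind of data dict the function expects. The values below are random stand-ins; only the keys and shapes match my real data (40 participants × 188 trials):

import numpy as np

# Stand-in data: random values, but keys and shapes match my real dataset
rng = np.random.default_rng(0)
data = {
    'Nparticipants': 40,
    'Ntrials': 188,
    'y': rng.uniform(1, 10, size=(40, 188)),     # responses (theta is bounded to [1, 10])
    'd_phy': rng.uniform(0, 5, size=(40, 188)),  # physical CS+/S distances
    'r': rng.integers(0, 2, size=(40, 188)),     # reinforcement indicator (0/1)
}

# Sanity check: per-variable log-probability at the initial point
print(model_1v_LG2D(data).point_logps())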

Since I added the likelihood (y_likelihood), sampling 4 chains with 2000 draws each is estimated to take over 45 hours. Here is the code I am using to sample:

# Build the model and sample within its context
with model_1v_LG2D(data) as model:
    # Trace
    trace = pm.sample(draws = 2000, tune = 2000, chains = 4, cores = 4,
                      return_inferencedata = True,
                      compute_convergence_checks = True,
                      discard_tuned_samples = True)

    # Prediction
    post_pred = pm.sample_posterior_predictive(trace)
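
In case it helps with diagnosis: as far as I understand, the compiled logp and gradient graphs can be profiled with Model.profile to see which ops dominate the runtime (a sketch, which I could run and post the output of if useful):

# Profile the log-probability graph and its gradient; the gradient is
# what NUTS evaluates at every leapfrog step, so the nested Scan should
# show up here if it dominates the runtime
model.profile(model.logp()).summary()
model.profile(model.dlogp()).summary()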

Without the likelihood, sampling takes about 18 minutes. I can’t switch to a different sampler from the default one, such as NumPyro or BlackJAX, because the model contains a discrete variable (gp_1) on which many of the continuous variables depend. The observed data and the likelihood parameters are 2D arrays of shape (40, 188) (participants × trials), so I understand that sampling will be slower. However, after 2 days it is still at 32%. Is it normal for it to be this slow, or is there something fundamentally wrong with the model?
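
(I’m aware that gp_1 could in principle be marginalized out so that only continuous variables remain and the JAX-based samplers become usable. Roughly the sketch below, where loglik_for_group is a hypothetical helper that re-runs the scans with the group fixed to one value and returns per-participant log-likelihoods of y. I haven’t attempted this yet, because the group indicator feeds into the scan recursion itself.)

# Sketch only: marginalize the categorical group indicator instead of
# sampling it. `loglik_for_group(k)` is a hypothetical helper that runs
# the trial/participant scans with gp fixed to k + 1 and returns a
# (Nparticipants,) vector of log-likelihood contributions for y.
logp_k = pt.stack(
    [pt.log(pi[k]) + loglik_for_group(k) for k in range(4)]
)  # shape (4, Nparticipants)
pm.Potential('y_marginal', pm.math.logsumexp(logp_k, axis=0).sum())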

I’m using Windows and running the model in a conda environment created on WSL. My PyMC version is 5.15.1 and PyTensor is 2.22.1.

Any help or advice would be much appreciated.