In order to port a complex PyMC3 model with custom conjugate steppers involving transformed variables to v5, I am trying to update this excellent old PyMC3 example from the docs to v5.
The first issue I am running into (I am sure there will be more) is that
pm.sample doesn’t appear to set up the steppers correctly. Here is the code, a reduced version of the example:
    import numpy as np
    import pymc as pm
    from pymc.distributions.transforms import sum_to_1
    from pymc.step_methods.arraystep import BlockedStep


    def sample_dirichlet(c):
        gamma = np.random.gamma(c)
        p = gamma / gamma.sum(axis=-1, keepdims=True)
        return p


    class ConjugateStep(BlockedStep):
        def __init__(self, var, counts: np.ndarray, concentration):
            self.vars = [var]
            self.name = var.name
            self.counts = counts
            self.conc_prior = concentration

        def step(self, point: dict):
            conc_posterior = np.exp(point[self.conc_prior.transformed.name]) + self.counts
            draw = sample_dirichlet(conc_posterior)
            point[self.name] = sum_to_1.forward(draw)
            return point,


    # Generate data
    J = 10
    N = 500
    ncounts = 20
    tau_true = 0.5
    alpha = tau_true * np.ones([N, J])
    p_true = sample_dirichlet(alpha)
    counts = np.zeros([N, J])
    for i in range(N):
        counts[i] = np.random.multinomial(ncounts, p_true[i])

    # Set up model and sample
    with pm.Model() as model:
        tau = pm.Exponential("tau", lam=1, initval=1.0)
        alpha = pm.Deterministic("alpha", tau * np.ones([N, J]))
        p = pm.Dirichlet("p", a=alpha)
        step = [ConjugateStep(p, counts, tau)]
        trace = pm.sample(step=step, chains=2, cores=1, return_inferencedata=True)
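For reference, the conjugate update that ConjugateStep is meant to perform can be shown without PyMC. This is only a minimal sketch using the standard library: the helper name `conjugate_draw` and the example counts are made up for illustration, and it handles a single row of counts rather than the full N x J array.

```python
import random


def sample_dirichlet(conc, rng):
    """Draw one Dirichlet sample by normalizing independent Gamma(c, 1) draws."""
    gammas = [rng.gammavariate(c, 1.0) for c in conc]
    total = sum(gammas)
    return [g / total for g in gammas]


def conjugate_draw(prior_conc, counts, rng):
    """Hypothetical helper: Dirichlet prior + multinomial counts -> posterior draw.

    The posterior concentration is the prior concentration plus the observed
    counts, element by element.
    """
    posterior_conc = [a + n for a, n in zip(prior_conc, counts)]
    return sample_dirichlet(posterior_conc, rng)


rng = random.Random(0)          # seeded only to make the sketch repeatable
tau = 0.5                       # same role as tau in the model above
counts = [7, 0, 3, 10]          # made-up counts for one row
draw = conjugate_draw([tau] * len(counts), counts, rng)
print(draw)
```

The stepper in the model does the same thing per row, except that it reads tau from the log-transformed value in the point and maps the draw through sum_to_1.forward before storing it.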
Here is the initial output from pm.sample before it crashes:

    Sequential sampling (2 chains in 1 job)
    CompoundStep
    >ConjugateStep: [p]
    >NUTS: [tau, p]
p shows up in both steps, which is not correct.
If I explicitly assign steppers for all parameters using
step=[ConjugateStep(p, counts, tau), pm.NUTS(tau)], I get

    Sequential sampling (2 chains in 1 job)
    CompoundStep
    >ConjugateStep: [p]
    >NUTS: [tau]
    >NUTS: [p]

with the same basic issue.
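To make the expected behaviour concrete, here is a toy sketch in plain Python of the bookkeeping I would expect. It is not PyMC's actual code: `split_variables` and the tuple representation of a step are hypothetical names used only for illustration.

```python
def split_variables(free_vars, user_steps):
    """Toy illustration, not PyMC code.

    free_vars:  names of all free variables in the model.
    user_steps: list of (step_name, [variable names]) pairs supplied by the user.
    Returns (claimed, remaining): variables already covered by a user step,
    and variables that still need an automatically assigned step method.
    """
    claimed = []
    for _, names in user_steps:
        for name in names:
            if name not in claimed:
                claimed.append(name)
    remaining = [v for v in free_vars if v not in claimed]
    return claimed, remaining


claimed, remaining = split_variables(["tau", "p"], [("ConjugateStep", ["p"])])
print(claimed, remaining)  # prints ['p'] ['tau']
```

In other words, I would expect only tau to be handed to NUTS once p has been claimed by ConjugateStep.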
The problem seems to be somewhere in the function assign_step_methods, but my understanding of the PyMC codebase is insufficient to figure out more.
How can this problem be fixed? Any help would be appreciated!