To port a complex PyMC3 model that uses custom conjugate steppers on transformed variables to v5, I am first trying to update this excellent old PyMC3 example from the docs.
The first issue I am running into (I am sure there will be more) is that `pm.sample` doesn't appear to set up the steppers correctly. Here is the code, a reduced version of the example:
```python
import numpy as np
import pymc as pm
from pymc.distributions.transforms import sum_to_1
from pymc.step_methods.arraystep import BlockedStep


def sample_dirichlet(c):
    # Dirichlet draw(s) obtained by normalizing independent Gamma(c, 1) variates
    gamma = np.random.gamma(c)
    p = gamma / gamma.sum(axis=-1, keepdims=True)
    return p


class ConjugateStep(BlockedStep):
    def __init__(self, var, counts: np.ndarray, concentration):
        self.vars = [var]
        self.name = var.name
        self.counts = counts
        self.conc_prior = concentration

    def step(self, point: dict):
        # Posterior concentration = prior concentration (tau, read from its
        # log-transformed value) + observed counts
        conc_posterior = np.exp(point[self.conc_prior.transformed.name]) + self.counts
        draw = sample_dirichlet(conc_posterior)
        # Map the draw into the transformed space before storing it
        point[self.name] = sum_to_1.forward(draw)
        return point, []


# Generate data
J = 10
N = 500
ncounts = 20
tau_true = 0.5
alpha = tau_true * np.ones([N, J])
p_true = sample_dirichlet(alpha)
counts = np.zeros([N, J])
for i in range(N):
    counts[i] = np.random.multinomial(ncounts, p_true[i])

# Set up model and sample
with pm.Model() as model:
    tau = pm.Exponential("tau", lam=1, initval=1.0)
    alpha = pm.Deterministic("alpha", tau * np.ones([N, J]))
    p = pm.Dirichlet("p", a=alpha)
    step = [ConjugateStep(p, counts, tau)]
    trace = pm.sample(step=step, chains=2, cores=1, return_inferencedata=True)
```
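Outside of PyMC, the conjugate update that `ConjugateStep.step` is meant to perform can be checked on its own. This is a minimal NumPy sketch under assumed toy sizes and a uniform true `p` (both hypothetical, much smaller than the N=500, J=10 above): since `tau` is sampled on the log scale, `np.exp` recovers it, and a Dirichlet(tau) prior combined with multinomial counts gives a Dirichlet(tau + counts) posterior, which is drawn by normalizing Gamma variates. Unlike the code above, it uses a seeded `np.random.Generator`.

```python
import numpy as np

rng = np.random.default_rng(0)


def sample_dirichlet(c, rng):
    """Draw one Dirichlet sample per row by normalizing independent Gamma(c, 1) draws."""
    g = rng.gamma(c)
    return g / g.sum(axis=-1, keepdims=True)


# Hypothetical toy sizes and a uniform true p, for illustration only
N, J, ncounts = 4, 3, 20
counts = np.array([rng.multinomial(ncounts, np.full(J, 1.0 / J)) for _ in range(N)])

# tau is stored on the log scale in the sampler's point dict
log_tau = np.log(0.5)
# Dirichlet(tau) prior + multinomial counts -> Dirichlet(tau + counts) posterior
conc_posterior = np.exp(log_tau) + counts

p_draw = sample_dirichlet(conc_posterior, rng)
print(p_draw.shape)                          # (4, 3)
print(np.allclose(p_draw.sum(axis=1), 1.0))  # True
```

Each row of `p_draw` is a valid probability vector, which is what the step then has to map into the transformed space PyMC samples in.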
Here is the initial output before it crashes:

```
Sequential sampling (2 chains in 1 job)
CompoundStep
>ConjugateStep: [p]
>NUTS: [tau, p]
```
The parameter `p` shows up in both steps, which is not correct.
If I explicitly assign steppers for all parameters using `step=[ConjugateStep(p, counts, tau), pm.NUTS(tau)]`, I get the same basic issue, with `p` now getting a NUTS step of its own:

```
Sequential sampling (2 chains in 1 job)
CompoundStep
>ConjugateStep: [p]
>NUTS: [tau]
>NUTS: [p]
```
The problem seems to be somewhere in the function `assign_step_methods`, but I don't know the PyMC codebase well enough to dig further.
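One guess, which I have not verified: in v5 the step methods appear to operate on the model's value variables (the transformed ones, reachable through `model.rvs_to_values`) rather than on the random variables themselves. If `assign_step_methods` only treats as taken the variables listed in each step's `vars`, then a custom step that lists the random variable `p` would claim nothing, and NUTS would still be assigned to `p`'s value variable. That would also explain why `pm.NUTS(tau)` works, if built-in steppers convert random variables to value variables internally. Below is a simplified, hypothetical model of that bookkeeping; `Var`, `FakeStep` and `unassigned_value_vars` are stand-ins, not PyMC code, and the value-variable names `p_simplex__` and `tau_log__` are what I assume v5 generates for the default transforms.

```python
# Hypothetical, simplified model of the bookkeeping I suspect happens when
# user-supplied steps are combined with automatically assigned ones.
# Var, FakeStep and unassigned_value_vars are illustrative stand-ins,
# NOT PyMC classes or functions.
from dataclasses import dataclass, field


@dataclass(frozen=True)
class Var:
    """Stand-in for a model variable, identified only by its name."""
    name: str


@dataclass
class FakeStep:
    """Stand-in for a step method: only records the variables it updates."""
    label: str
    vars: list = field(default_factory=list)


def unassigned_value_vars(value_vars, steps):
    """Return the value variables not claimed by any user-supplied step.

    Membership is checked against value variables, so a step whose `vars`
    holds a random variable instead of its value variable claims nothing.
    """
    assigned = set()
    for step in steps:
        assigned.update(step.vars)
    return [v for v in value_vars if v not in assigned]


# Random variables as created inside the model block
p_rv, tau_rv = Var("p"), Var("tau")
# Their (assumed) transformed value variables, i.e. the keys that appear in `point`
p_val, tau_val = Var("p_simplex__"), Var("tau_log__")
value_vars = [tau_val, p_val]

# Step lists the random variable: p_simplex__ stays free and would go to NUTS
print(unassigned_value_vars(value_vars, [FakeStep("Conjugate", [p_rv])]))
# Step lists the value variable: only tau_log__ is left for NUTS
print(unassigned_value_vars(value_vars, [FakeStep("Conjugate", [p_val])]))
```

If this guess is right, constructing the step with `model.rvs_to_values[p]` (and reading `tau` from `point` under its value variable's name) might be the fix, but I have not tested that.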
How can this problem be fixed? Any help would be appreciated!