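I removed the Dirichlet's default transform (`transform=None`), so that the custom conjugate step can write the sampled probabilities directly into `p`. Here is a working snippet: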
import numpy as np
import pymc as pm
from pymc.step_methods.arraystep import BlockedStep
def sample_dirichlet(c, rng=None):
    """Draw Dirichlet samples by normalising independent Gamma draws.

    Each slice along the last axis of ``c`` is one concentration vector, so
    ``c`` of shape ``(..., K)`` yields probabilities of the same shape whose
    last axis sums to 1.

    Parameters
    ----------
    c : array_like
        Positive concentration parameters; the last axis indexes categories.
    rng : numpy.random.Generator or numpy.random.RandomState, optional
        Source of randomness. ``None`` (the default) uses the global
        ``np.random`` state.

    Returns
    -------
    numpy.ndarray
        Samples with the same shape as ``c``.
    """
    rng = np.random if rng is None else rng
    c = np.asarray(c, dtype=float)
    gamma = rng.gamma(c)
    total = gamma.sum(axis=-1, keepdims=True)
    # With very small concentrations every Gamma draw in a row can underflow
    # to exactly 0, which would make the plain normalisation 0/0 = NaN.
    underflow = total == 0
    with np.errstate(divide="ignore", invalid="ignore"):
        p = gamma / total
    if np.any(underflow):
        # Redraw in log space using Gamma(c) = Gamma(c + 1) * U**(1/c) with
        # U ~ Uniform(0, 1], then normalise with a max-shifted softmax. Only
        # the rows that underflowed take these values; other rows are kept.
        with np.errstate(divide="ignore", invalid="ignore"):
            log_g = np.log(rng.gamma(c + 1.0)) + (
                np.log1p(-rng.uniform(size=c.shape)) / c
            )
        log_g = log_g - log_g.max(axis=-1, keepdims=True)
        g = np.exp(log_g)
        p = np.where(underflow, g / g.sum(axis=-1, keepdims=True), p)
    return p
class ConjugateStep(BlockedStep):
    """Sampler step that redraws a Dirichlet variable from a conjugate update.

    Each call reads the current concentration value from the point, adds the
    fixed ``counts`` and draws the variable from
    ``Dirichlet(concentration + counts)``.

    ``var`` should be created with ``transform=None``: the draw is written
    straight into its value variable as raw probabilities.
    """

    def __init__(self, var, counts: np.ndarray, concentration):
        """Bind the step to ``var`` in the current model context.

        Parameters
        ----------
        var : TensorVariable
            The Dirichlet random variable to update (untransformed).
        counts : np.ndarray
            Category counts; ``concentration + counts`` must have the shape
            of ``var``.
        concentration : TensorVariable
            The concentration random variable, whose value is read from each
            point.
        """
        model = pm.modelcontext(None)
        value_var = model.rvs_to_values[var]
        self.vars = [value_var]
        self.name = value_var.name
        # Keep a private float copy so later in-place edits to the caller's
        # array cannot silently change the posterior being sampled.
        self.counts = np.array(counts, dtype=float)
        self.conc_prior_name = model.rvs_to_values[concentration].name

    def step(self, point: dict):
        """Return a new point with ``var`` redrawn, plus an empty stats list."""
        # NOTE(review): np.exp assumes the concentration's value variable is
        # on the log scale (the log transform PyMC applies to positive priors
        # such as Exponential). Confirm this if the prior or transform changes.
        conc_posterior = np.exp(point[self.conc_prior_name]) + self.counts
        # Build a new dict instead of mutating the caller's point in place.
        new_point = dict(point)
        new_point[self.name] = sample_dirichlet(conc_posterior)
        return new_point, []
# Generate synthetic data: true probabilities for each of the N rows come
# from a symmetric Dirichlet(tau_true), and each row then gets a
# multinomial draw of ncounts items over J categories.
J = 10
N = 500
ncounts = 20
tau_true = 0.5
alpha = np.full([N, J], tau_true)
p_true = sample_dirichlet(alpha)
counts = np.array(
    [np.random.multinomial(ncounts, row) for row in p_true],
    dtype=float,
)
# Build the model and sample. The custom ConjugateStep updates p; PyMC
# assigns its own default sampler (presumably NUTS) to the remaining tau.
with pm.Model() as model:
    tau = pm.Exponential("tau", lam=1, initval=1.0)
    alpha = pm.Deterministic("alpha", tau * np.ones([N, J]))
    # No transform: ConjugateStep writes raw simplex values into p directly.
    p = pm.Dirichlet("p", a=alpha, transform=None)
    step = [
        ConjugateStep(p, counts, tau),
    ]
    trace = pm.sample(
        step=step,
        chains=2,
        cores=1,
        return_inferencedata=True,
    )