Incorrect step assignments with custom step function

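I removed the Dirichlet's default transform (`transform=None`), because the custom step writes raw simplex draws straight into the sampler's point. Here is a working snippet: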

import numpy as np
import pymc as pm
from pymc.step_methods.arraystep import BlockedStep

def sample_dirichlet(c, rng=None):
    """Draw Dirichlet samples by normalizing independent Gamma draws.

    Each slice of ``c`` along its last axis is one concentration vector, so
    the result has the same shape as ``c`` and sums to 1 along the last axis.

    Parameters
    ----------
    c : array_like
        Positive concentration parameters; the last axis indexes categories.
    rng : numpy.random.Generator or numpy.random.RandomState, optional
        Source of randomness. When omitted, NumPy's global random state is
        used, exactly as before, so existing callers get the same stream.

    Returns
    -------
    numpy.ndarray
        Dirichlet draws with the same shape as ``c``.
    """
    source = np.random if rng is None else rng
    gamma = source.gamma(c)
    # Normalizing i.i.d. Gamma(c_k, 1) variables yields a Dirichlet(c) draw.
    return gamma / gamma.sum(axis=-1, keepdims=True)

class ConjugateStep(BlockedStep):
    """Gibbs step that redraws a Dirichlet variable from its full conditional.

    The update assumes the prior on ``var`` is ``Dirichlet(c * ones)`` with a
    scalar concentration ``c``, and that ``counts`` are multinomial
    observations of ``var``. By conjugacy the full conditional is then
    ``Dirichlet(c + counts)``, which is what ``step`` samples from.

    Parameters
    ----------
    var
        Dirichlet random variable to update. It must be created with
        ``transform=None``: the new draw is written into the point as a raw
        simplex value, which would be wrong for a transformed value variable.
    counts : numpy.ndarray
        Observed counts, added elementwise to the concentration.
    concentration
        Random variable holding ``c``. Its value variable may be untransformed
        or log-transformed (PyMC's default for positive priors such as
        ``pm.Exponential``); any other transform is rejected.

    Raises
    ------
    ValueError
        If ``var`` has a transform, or ``concentration`` has a transform other
        than log.
    """

    def __init__(self, var, counts: np.ndarray, concentration):
        model = pm.modelcontext(None)
        value_var = model.rvs_to_values[var]
        # PyMC names transformed value variables "<name>_<transform>__", so an
        # unchanged name means the variable is sampled on its natural scale.
        if value_var.name != var.name:
            raise ValueError(
                f"ConjugateStep requires {var.name!r} to be untransformed "
                f"(found value variable {value_var.name!r}); "
                "create it with transform=None"
            )
        self.vars = [value_var]
        self.name = value_var.name
        self.counts = counts

        conc_value = model.rvs_to_values[concentration]
        self.conc_prior_name = conc_value.name
        if conc_value.name == concentration.name:
            self._conc_is_log = False
        elif conc_value.name == f"{concentration.name}_log__":
            self._conc_is_log = True
        else:
            raise ValueError(
                f"ConjugateStep only supports an untransformed or "
                f"log-transformed concentration (found value variable "
                f"{conc_value.name!r})"
            )

    def step(self, point: dict):
        """Return a copy of ``point`` with ``var`` replaced by a posterior draw.

        Returns
        -------
        tuple
            ``(new_point, [])``: the updated point and an empty stats list.
        """
        conc = point[self.conc_prior_name]
        if self._conc_is_log:
            # The point stores log(concentration); map back to natural scale.
            conc = np.exp(conc)
        new_point = dict(point)  # do not mutate the caller's point in place
        new_point[self.name] = sample_dirichlet(conc + self.counts)
        return new_point, []

# Generate data: N rows of J-category multinomial counts, each row drawn
# from its own Dirichlet(tau_true * ones) probability vector.
J = 10
N = 500
ncounts = 20
tau_true = 0.5
alpha = np.full((N, J), tau_true)
p_true = sample_dirichlet(alpha)
counts = np.array(
    [np.random.multinomial(ncounts, probs) for probs in p_true],
    dtype=float,
)

# Set up model and sample. p is updated by the custom ConjugateStep;
# pm.sample assigns a default step method to any free variable not covered
# by `step` (here tau).
model = pm.Model()
with model:
    tau = pm.Exponential("tau", lam=1, initval=1.0)
    # NOTE(review): this Deterministic is an N x J array and is recorded on
    # every draw; consider passing the expression to `a=` directly if trace
    # size matters.
    alpha = pm.Deterministic("alpha", tau * np.ones((N, J)))
    p = pm.Dirichlet("p", a=alpha, transform=None)
    # Positional args on purpose: BlockedStep.__new__ inspects args[0].
    step = [ConjugateStep(p, counts, tau)]
    trace = pm.sample(
        step=step,
        chains=2,
        cores=1,
        return_inferencedata=True,
    )
3 Likes