Unpooled Dirichlet Process

I’m having trouble defining an unpooled Dirichlet process model (stick breaking). Everything I’ve tried raises an error about inconsistent dimensions. Below is some example code I pulled from Bayesian Analysis with Python (3rd Ed.), to which I’ve added a coords dictionary for weekdays. How can I add dims arguments to the model definition so that I get an unpooled model for each day in the dataset? I suspect the shape arguments aren’t playing nice with the dims arguments. Perhaps I need to omit the dims arguments and make the shape arguments multidimensional?

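import numpy as np
import pandas as pd
import pymc as pm
import pytensor.tensor as pt

# `tips` and `cs_exp` are assumed to be loaded already, as in the book's examples.
K = 10  # truncation level: maximum number of mixture components

def stick_breaking(α, K):
    β = pm.Beta('β', 1., α, shape=K)
    w = β * pt.concatenate([[1.], pt.extra_ops.cumprod(1. - β)[:-1]]) + 1E-6
    return w/w.sum()

# The book's tips data labels Thursday as "Thur"; labels missing from `categories` get code -1.
categories = np.array(["Thur", "Fri", "Sat", "Sun"])
idx = pd.Categorical(tips["day"], categories=categories).codes
coords = {"days": categories, "days_flat": categories[idx]}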
with pm.Model(coords=coords) as model_DP:
    α = pm.Gamma('α', 2, 1)
    w = pm.Deterministic('w', stick_breaking(α, K)) 
    means = pm.Normal('means',
                      mu=np.linspace(cs_exp.min(), cs_exp.max(), K),
                      sigma=5, shape=K,
                      transform=pm.distributions.transforms.univariate_ordered,
                     )
    
    sd = pm.HalfNormal('sd', sigma=5, shape=K)
    obs = pm.NormalMixture('obs', w, means, sigma=sd, observed=cs_exp.values)
    idata = pm.sample(random_seed=123, target_accept=0.9)
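
For reference, the arithmetic inside stick_breaking can be checked outside PyMC. The NumPy-only sketch below mirrors it with random numbers standing in for the Beta draws; the name stick_breaking_np, the seed, and the fixed value used for α are illustrative assumptions, not part of the model.

```python
import numpy as np

def stick_breaking_np(beta):
    """NumPy mirror of stick_breaking: w_k = beta_k * prod_{j<k} (1 - beta_j)."""
    # Stick length remaining before each break; nothing is broken off before the first.
    remaining = np.concatenate([[1.0], np.cumprod(1.0 - beta)[:-1]])
    w = beta * remaining + 1e-6   # small offset keeps every weight strictly positive
    return w / w.sum()            # renormalize because the stick is truncated at K pieces

rng = np.random.default_rng(0)    # illustrative seed
K = 10
alpha = 2.0                       # stand-in for a single draw of α
beta = rng.beta(1.0, alpha, size=K)
w = stick_breaking_np(beta)
print(w.shape, w.sum())
```

The result is a single weight vector of length K, which is why the pooled model only ever needs one-dimensional shapes.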

After a lot of reading and experimenting, I believe I’ve almost solved the problem I described above. First, though, I want to mention that I think this PyMC documentation page could be improved. It provides a lot of great information about how broadcasting and the shape argument work, but nothing about how the dims argument interacts with the shape argument. Since I couldn’t find that information anywhere, I decided to drop the dims argument and use only the shape argument.

Below is an updated version of the example code above. I’ve checked that prior predictive sampling succeeds with every line in place except the one defining the likelihood, so hopefully everything but that line is well-defined. I believe the problem is that I’m indexing the unpooled priors incorrectly in the likelihood. I’m not aware of an example where the priors have more than one dimension (in linear algebra parlance, not PyMC parlance). How can I fix the likelihood definition?

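import numpy as np
import pandas as pd
import pymc as pm
import pytensor.tensor as pt

K = 10            # maximum number of mixture components per day
n_categories = 4  # number of weekdays in the data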

def stick_breaking(α, K):
    β = pm.Beta('β', 1., α, shape=(K, n_categories))
    w = β * pt.concatenate([np.ones((1, n_categories)), pt.extra_ops.cumprod(1. - β, axis=0)[:-1]]) + 1E-6
    return w/w.sum(axis=0)

idx = pd.Categorical(tips["day"], categories=["Thur", "Fri", "Sat", "Sun"]).codes
with pm.Model() as model_DP:
    α = pm.Gamma('α', 2, 1, shape=n_categories)
    w = pm.Deterministic('w', stick_breaking(α, K))
    means = pm.Normal('means',
                      mu=np.tile(np.linspace(cs_exp.min(), cs_exp.max(), K), (n_categories, 1)).T,
                      sigma=5*np.ones((K, n_categories)),
                      transform=pm.distributions.transforms.univariate_ordered,
                     )
    
    sd = pm.HalfNormal('sd', sigma=5, shape=(K, n_categories))
    obs = pm.NormalMixture('obs', w[:, idx], means[:, idx], sigma=sd[:, idx], observed=cs_exp.values)
    idata = pm.sample(random_seed=123, target_accept=0.9)
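
One thing worth checking is which axis holds the K components. The sketch below rests on an assumption that should be confirmed against the PyMC docs: that pm.NormalMixture treats the last axis of w as the component axis, and that the ordered transform also orders along the last axis. It is NumPy-only, uses made-up sizes (n_obs) and random numbers in place of the Beta draws and weekday codes, and simply compares the shapes the two layouts produce after indexing.

```python
import numpy as np

rng = np.random.default_rng(0)                   # illustrative seed
K, n_categories, n_obs = 10, 4, 8                # n_obs is a made-up number of observations
idx = rng.integers(0, n_categories, size=n_obs)  # stand-in for the weekday codes

# One row of stick-breaking fractions per category, components along the LAST axis.
beta = rng.beta(1.0, 2.0, size=(n_categories, K))
remaining = np.concatenate(
    [np.ones((n_categories, 1)), np.cumprod(1.0 - beta, axis=-1)[:, :-1]], axis=-1
)
w = beta * remaining + 1e-6
w = w / w.sum(axis=-1, keepdims=True)            # each category's weights sum to 1

w_obs = w[idx]          # shape (n_obs, K): one weight vector per observation
w_obs_T = w.T[:, idx]   # shape (K, n_obs): the layout produced by w[:, idx] in the model above
print(w_obs.shape, w_obs_T.shape)
```

If the assumption holds, the PyMC analogue would be to give β, means, and sd the shape (n_categories, K), broadcast α as α[:, None], and index the likelihood with w[idx], means[idx], and sd[idx]. This is untested against the model above.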