How to improve ESS and R_hat for DPMM?

I made a DPMM, largely following along with the PyMC example on Dirichlet processes for density estimation, but using some simulated data.

I am getting low ESS for those parameters that correctly identify the clusters – mostly in the tens. The Rhats are also poor. How can I get the ESS higher / Rhats lower?

It looks like the issue is that the parameter estimates for the “real” clusters jump around amongst the correct values: e.g. if there were two clusters, with means 8 and 14, the first two mean parameters would each jump between 8 and 14. So is the step size the problem? Or is this the “label switching” issue I have seen mentioned in connection with clustering methods?

Here is the sample code I was using:

import pymc as pm
import numpy as np


# simulate data
rng = np.random.default_rng(seed=3)
# 3 clusters
N_per_cluster = 100
c1 = rng.normal(3, 1, N_per_cluster)
c2 = rng.normal(6, 0.5, N_per_cluster)
c3 = rng.normal(13, 2, N_per_cluster)
simdat = np.concatenate([c1, c2, c3])


# model setup
K_true = 12  # truncation level; StickBreakingWeights returns K + 1 weights, so I'll pass K_minus_1 to it and use K_true everywhere else
K_minus_1 = K_true - 1
coords = {
    "cluster": np.arange(K_true) + 1,
    "obs_num": np.arange(simdat.shape[0]) + 1
}


# model
with pm.Model(coords=coords) as model:
    # Data
    x_data = pm.Data("x_data", simdat, dims="obs_num")

    # priors
    # stick breaking
    alpha = pm.Gamma("alpha", alpha=2, beta=0.75)
    w = pm.StickBreakingWeights("w", alpha=alpha, K=K_minus_1, dims="cluster")
    # cluster parameterization
    # K mus and K sigmas; one for each cluster considered (bounded by K).
    mu = pm.Normal("mu", mu=9, sigma=6, dims="cluster")
    sigma = pm.Gamma("sigma", alpha=2, beta=1, dims="cluster")

    # clustering
    components = pm.Normal.dist(mu=mu, sigma=sigma)
    obs = pm.Mixture("obs", w=w, comp_dists=components, observed=x_data, dims="obs_num")


# sample
with model:
    idata = pm.sample()
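For reference, this is how I’m reading off the diagnostics, using the standard ArviZ calls (nothing fancy):

import arviz as az

# bulk/tail ESS and R-hat for the cluster parameters
print(az.summary(idata, var_names=["mu", "w"])[["ess_bulk", "ess_tail", "r_hat"]])

# the jumping between cluster means is visible in the trace plots
az.plot_trace(idata, var_names=["mu"])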

This is indeed label switching. The typical way to address this is to use an ordering constraint (there are many examples of this on the forum).
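For concreteness, the ordering constraint usually looks like replacing the `mu` line in your model with something like the sketch below (the `initval` values are arbitrary, they just need to be strictly increasing so the ordered transform has a valid starting point):

    mu = pm.Normal(
        "mu", mu=9, sigma=6, dims="cluster",
        transform=pm.distributions.transforms.ordered,
        initval=np.linspace(1, 15, K_true),
    )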

In your case, however, the stick-breaking process already imposes a “kind of” ordering constraint on the weights – earlier components tend to get more mass, so the most populous cluster tends to come first – while an order constraint on the means would force that same first component to also have the lowest mean.

No idea how to square that circle. It doesn’t strike me as the type of problem with an easy solution. I know DPs are notoriously hard to work with – there are probably specialized samplers adapted for them?

Thanks for the pointer on order constraints. I guess I’ll leave this to the experts. I was trying my hand at it after reading a paper that used something called Bayesian profile regression, which combines a DPMM with a regression model.

For posterity, here are some references that might be useful:

  1. Li, Y., Schofield, E., and Gönen, M. (2019). A tutorial on Dirichlet Process mixture modeling. Journal of Mathematical Psychology, 91, 128–144. doi:10.1016/j.jmp.2019.04.004. link
    • Includes an annotated DPMM implementation in base R
  2. Liverani, S., Hastie, D. I., Azizi, L., Papathomas, M., and Richardson, S. (2015). PReMiuM: An R package for Profile Regression Mixture Models using Dirichlet Processes. Journal of Statistical Software, 64(7), 1–30. link
    • Includes a formal description of the MCMC sampler implementation (not that I understood it).
  3. Molitor, J., Papathomas, M., Jerrett, M., and Richardson, S. (2010). Bayesian profile regression with an application to the National Survey of Children’s Health. Biostatistics, 11, 484–498. doi:10.1093/biostatistics/kxq013. link
    • A much easier read than reference #2, if interested in this method

Glancing over the references, it looks like they use a Gibbs sampler based on the Chinese Restaurant Process for sampling DPs in general (see Algorithm 1 in the 2nd paper). Specifically for the profile regression, they do “Metropolis-within-Gibbs”, which is an old-school way of saying they have a BlockedStep, where certain variables are updated by Gibbs steps and others by Metropolis.

Focusing on the cluster labels, means, and covariances, it appears the setup is conditionally conjugate for each data point’s label assignment, given all other label assignments and fixed observation noise (\sigma_y in the paper). I played around with the sampler written in R in Appendix B of that paper (I just threw it into GPT and asked for a Python version), and it does not seem to suffer from label switching, so having this as a joint sampling step over the implicated variables would be nice.
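
To make the label-update step concrete, here is a minimal sketch for the 1-D case – not the Appendix B sampler itself, just a collapsed CRP Gibbs sweep assuming a fixed observation sd (sigma_y) and a conjugate Normal(mu0, tau0) prior on the cluster means, with the means integrated out. All hyperparameter values are placeholders:

import numpy as np
from scipy.stats import norm

def crp_gibbs_labels(x, z, alpha=1.0, mu0=9.0, tau0=6.0, sigma_y=1.0, rng=None):
    """One Gibbs sweep over the integer cluster labels z for data x."""
    rng = np.random.default_rng() if rng is None else rng
    z = z.copy()
    for i in range(len(x)):
        z[i] = -1  # remove point i from its current cluster
        labels, counts = np.unique(z[z >= 0], return_counts=True)

        log_p = []
        for k, n_k in zip(labels, counts):
            xk = x[z == k]
            # posterior of cluster k's mean given its remaining members
            prec = 1.0 / tau0**2 + n_k / sigma_y**2
            m_k = (mu0 / tau0**2 + xk.sum() / sigma_y**2) / prec
            # predictive density of x_i under cluster k (mean integrated out)
            s_k = np.sqrt(1.0 / prec + sigma_y**2)
            log_p.append(np.log(n_k) + norm.logpdf(x[i], m_k, s_k))
        # opening a new cluster, proportional to alpha
        log_p.append(np.log(alpha) + norm.logpdf(x[i], mu0, np.sqrt(tau0**2 + sigma_y**2)))

        log_p = np.array(log_p)
        p = np.exp(log_p - log_p.max())
        p /= p.sum()
        choice = rng.choice(len(p), p=p)
        z[i] = labels[choice] if choice < len(labels) else z.max() + 1
    return z

# e.g. a few sweeps over the simulated data, starting from a single cluster:
# z = np.zeros(len(simdat), dtype=int)
# for _ in range(50):
#     z = crp_gibbs_labels(simdat, z, sigma_y=1.5)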

There has been some work on DPs here and here by @larryshamalama. Given the conditional conjugacy, though, I think another approach would be to add a DP case to the machinery proposed in this PR by @ricardoV94, which allows “optimized sampling” by first exploiting conjugate relationships. For this, we’d first have to write the model in an ‘unmarginalized’ way:

with pm.Model(coords=coords) as model:
    x_data = pm.Data("x_data", simdat, dims="obs_num")

    alpha = pm.Gamma("alpha", alpha=2, beta=0.75)
    w = pm.StickBreakingWeights("w", alpha=alpha, K=K_minus_1, dims="cluster")
    mu = pm.Normal("mu", mu=9, sigma=6, dims="cluster")
    sigma = pm.Gamma("sigma", alpha=2, beta=1, dims="cluster")

    class_idx = pm.Categorical("class_idx", p=w, dims="obs_num")
    obs = pm.Normal("obs", mu=mu[class_idx], sigma=sigma[class_idx], observed=x_data, dims="obs_num")
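
(As an aside, written this way the model can already be sampled with PyMC’s built-in steps in a Metropolis-within-Gibbs fashion – the labels get a Gibbs-style categorical step and the remaining continuous parameters are auto-assigned to NUTS. This is not the CRP-based update from the papers, and I’d still expect poor mixing, but it shows the blocking:)

with model:
    # discrete labels -> CategoricalGibbsMetropolis; everything else -> NUTS (auto-assigned)
    idata = pm.sample(step=[pm.CategoricalGibbsMetropolis([class_idx])])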

We could then look for the DP in this model, defined as:

  1. A Normal likelihood, with
  2. mu and sigma indexed, where
  3. the index is a categorical variable, and
  4. the weights of that categorical variable are stick-breaking weights.

This is what would signal a DP, and we could then sample mu, sigma, and the class_idx assignments jointly with a CRP-based step.
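
Pattern-matching that signature could in principle be done by walking the graph upstream of the likelihood. A rough, hypothetical sketch – not the machinery from the PR, just generic PyTensor graph utilities, and it assumes the internal op class names are NormalRV, CategoricalRV, and StickBreakingWeightsRV:

from pytensor.graph.basic import ancestors

def looks_like_dp_mixture(model, lik_name="obs", idx_name="class_idx", w_name="w"):
    """Crude check for the 'DP signature' described above."""
    lik, idx, w = model[lik_name], model[idx_name], model[w_name]

    def op_name(var):
        return type(var.owner.op).__name__ if var.owner is not None else None

    return (
        op_name(lik) == "NormalRV"                   # 1. Normal likelihood
        and idx in set(ancestors(lik.owner.inputs))  # 2./3. parameters indexed by the categorical
        and op_name(idx) == "CategoricalRV"
        and w in set(ancestors(idx.owner.inputs))    # 4. whose weights are stick-breaking weights
        and op_name(w) == "StickBreakingWeightsRV"
    )

# looks_like_dp_mixture(model)  # -> True for the unmarginalized model above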