I made a DPMM, largely following along with the PyMC example on Dirichlet processes for density estimation, but using some simulated data.
I am getting low ESS for the parameters that correctly identify the clusters – mostly in the tens – and the Rhats are also poor. How can I get the ESS higher and the Rhats lower?
It looks like the issue is that the parameter estimates for the “real” clusters jump around among the correct values. For example, if there were two clusters, at means 8 and 14, then the first two mean parameters would each jump between 8 and 14. At first I suspected a step-size problem, but I now think this is the “label switching” issue I have seen mentioned in connection with clustering methods?
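To convince myself this is what's happening, here is a small self-contained numpy sketch (the `split_rhat` helper and the numbers are my own, purely for illustration, not from my model): two chains that have swapped the labels of well-separated clusters produce a huge Rhat, even though each chain on its own looks perfectly converged.

```python
import numpy as np

def split_rhat(chains):
    """Simplified split-Rhat for draws shaped (n_chains, n_draws)."""
    n_chains, n_draws = chains.shape
    half = n_draws // 2
    # split each chain in half and treat the halves as separate chains
    split = chains[:, : 2 * half].reshape(2 * n_chains, half)
    within = split.var(axis=1, ddof=1).mean()        # W: mean within-chain variance
    between = half * split.mean(axis=1).var(ddof=1)  # B: between-chain variance
    var_est = (half - 1) / half * within + between / half
    return np.sqrt(var_est / within)

rng = np.random.default_rng(0)
# chain 0 thinks mu[0] is the cluster at 8; chain 1 thinks it's the one at 14
switched = np.stack([rng.normal(8, 0.1, 1000), rng.normal(14, 0.1, 1000)])
# for comparison: both chains agree that mu[0] is the cluster at 8
agreeing = np.stack([rng.normal(8, 0.1, 1000), rng.normal(8, 0.1, 1000)])
print(split_rhat(switched))  # far above 1
print(split_rhat(agreeing))  # close to 1
```

So a poor Rhat here wouldn't mean the sampler failed to explore the posterior, just that the chains disagree about which component carries which label.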
Here is the sample code I was using:
import pymc as pm
import numpy as np
# simulate data
rng = np.random.default_rng(seed=3)
# 3 clusters
N_per_cluster = 100
c1 = rng.normal(3, 1, N_per_cluster)
c2 = rng.normal(6, 0.5, N_per_cluster)
c3 = rng.normal(13, 2, N_per_cluster)
simdat = np.concatenate([c1, c2, c3])
# model setup
K_true = 12  # I'll pass K_minus_1 to StickBreakingWeights and otherwise use K_true
K_minus_1 = K_true - 1
coords = {
    "cluster": np.arange(K_true) + 1,
    "obs_num": np.arange(simdat.shape[0]) + 1,
}
# model
with pm.Model(coords=coords) as model:
    # data
    x_data = pm.Data("x_data", simdat, dims="obs_num")
    # priors
    # stick-breaking weights
    alpha = pm.Gamma("alpha", alpha=2, beta=0.75)
    w = pm.StickBreakingWeights("w", alpha=alpha, K=K_minus_1, dims="cluster")
    # cluster parameterization:
    # K mus and K sigmas, one for each cluster considered (bounded by K)
    mu = pm.Normal("mu", mu=9, sigma=6, dims="cluster")
    sigma = pm.Gamma("sigma", alpha=2, beta=1, dims="cluster")
    # mixture likelihood over the K components
    components = pm.Normal.dist(mu=mu, sigma=sigma)
    obs = pm.Mixture("obs", w=w, comp_dists=components, observed=x_data, dims="obs_num")
# sample
with model:
    idata = pm.sample()