Typically, if you are getting better results with a “simpler” method (MAP, Metropolis, etc.), your model has some kind of misspecification. One of the best and most frustrating things about NUTS is that it fails very loudly.
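One of the loud signals is R-hat. The NumPy sketch below implements the classic split-R-hat for a single parameter (ArviZ's `az.rhat` defaults to a rank-normalized variant, so its numbers will differ slightly). The chains are synthetic and made up for illustration: chains stuck in different modes, which is what label switching in a mixture looks like, push R-hat far above 1.

```python
import numpy as np


def split_rhat(chains: np.ndarray) -> float:
    """Classic split R-hat for one scalar parameter.

    chains: array of shape (n_chains, n_draws).
    """
    n_chains, n_draws = chains.shape
    half = n_draws // 2
    # Split each chain in half so drift within a chain also inflates R-hat
    split = np.concatenate([chains[:, :half], chains[:, half:2 * half]], axis=0)
    n = split.shape[1]
    chain_means = split.mean(axis=1)
    w = split.var(axis=1, ddof=1).mean()      # mean within-chain variance
    b = n * chain_means.var(ddof=1)           # between-chain variance
    var_hat = (n - 1) / n * w + b / n         # pooled estimate of posterior variance
    return float(np.sqrt(var_hat / w))


rng = np.random.default_rng(0)
# Four chains exploring the same mode
mixed = rng.normal(0.0, 1.0, size=(4, 1000))
# Two chains stuck near -2 and two near +2 (e.g. swapped cluster labels)
stuck = np.concatenate([rng.normal(-2.0, 0.3, size=(2, 1000)),
                        rng.normal(2.0, 0.3, size=(2, 1000))])

print(split_rhat(mixed))  # close to 1
print(split_rhat(stuck))  # far above 1
```

MAP or a poorly mixing Metropolis run can land in one of those modes and look fine; multiple NUTS chains make the disagreement visible.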
I’m a beginner when it comes to mixture models, but I spent some time playing with your model. As written, it produces nonsensical results in the non-sorted dimension. After some consultation with @ricardoV94, I think it’s really important that you use a mixture of multivariate Normals rather than a mixture of Normals. My intuition (and likely yours) was that these are equivalent, but in the context of a mixture model they’re not. A mixture of all Normals is like a mixture of 10 exchangeable components, whereas a mixture of 5 multivariate Normals encodes an unbreakable connection between each (x, y) pair.
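To make the difference concrete, here is a small NumPy sketch with two made-up 2-D clusters. It assumes “mixture of Normals” means the mixing happens separately along each coordinate, so x and y can pick different components. Under that reading, a “chimera” point whose x comes from one cluster and whose y comes from another scores exactly as well as a real point, while the joint (multivariate) mixture penalizes it heavily.

```python
import numpy as np


def normal_logpdf(x, mu, sigma):
    return -0.5 * np.log(2 * np.pi * sigma**2) - 0.5 * ((x - mu) / sigma) ** 2


def logsumexp(a, axis):
    a_max = np.max(a, axis=axis, keepdims=True)
    return np.squeeze(a_max, axis=axis) + np.log(np.sum(np.exp(a - a_max), axis=axis))


# Two well-separated clusters (illustrative values), shape (cluster, coord)
mu = np.array([[-3.0, -3.0],
               [3.0, 3.0]])
sigma = 0.5
log_w = np.log(np.array([0.5, 0.5]))


def joint_mixture_logp(point):
    # One label per point: x and y must come from the same component
    comp = normal_logpdf(point, mu, sigma).sum(axis=-1)   # (cluster,)
    return logsumexp(log_w + comp, axis=0)


def per_coord_mixture_logp(point):
    # Independent mixture per coordinate: x and y may use different components
    comp = normal_logpdf(point, mu, sigma)                # (cluster, coord)
    return logsumexp(log_w[:, None] + comp, axis=0).sum()


real = np.array([-3.0, -3.0])     # both coordinates from cluster 0
chimera = np.array([-3.0, 3.0])   # x from cluster 0, y from cluster 1

print(joint_mixture_logp(real), joint_mixture_logp(chimera))
print(per_coord_mixture_logp(real), per_coord_mixture_logp(chimera))
```

The per-coordinate version cannot tell the two points apart, which is one way the non-sorted dimension can drift into nonsense.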
After that change, things got much more sensible. To get it across the finish line I had to tinker with priors and sampling options. I needed quite tight sigma priors to make sure the distributions didn’t “bleed” into each other. I also got better results with the PyMC NUTS implementation than with the JAX one, but I don’t have a good explanation for that. My hypothesis is that the unimodal nature of ADVI pushed the mixture weights towards 1/n (1/5 here), which ends up being the right answer. You could also experiment with setting the Dirichlet concentration parameter higher. Without the ADVI initialization, the sampler often found sparse solutions, which we really don’t want (in this problem, anyway).
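A quick way to see what the concentration parameter does is to draw weights from the Dirichlet prior alone. This NumPy sketch counts how often some cluster gets a near-empty weight; the 2% threshold is an arbitrary choice for illustration, and it only looks at the prior, not the posterior.

```python
import numpy as np


def sparse_fraction(alpha, n_clusters=5, n_draws=10_000, threshold=0.02, seed=0):
    """Fraction of Dirichlet(alpha, ..., alpha) draws whose smallest weight is below threshold."""
    rng = np.random.default_rng(seed)
    w = rng.dirichlet(np.full(n_clusters, alpha), size=n_draws)
    return float(np.mean(w.min(axis=1) < threshold))


for alpha in (0.5, 1.0, 10.0):
    print(f"alpha={alpha}: fraction of draws with a near-empty cluster = {sparse_fraction(alpha):.3f}")
```

With alpha below 1 the prior actively favours sparse weight vectors; with alpha = 10 (the value used in the model) near-empty clusters essentially never appear in the prior.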
Here’s the data generation:
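import matplotlib.pyplot as plt
import numpy as np
import pymc as pm
import pytensor.tensor as pt
from sklearn.datasets import make_blobs
from sklearn.preprocessing import StandardScaler

n_clusters = 5
data, labels = make_blobs(n_samples=1000, centers=n_clusters, random_state=10)

# Standardize both coordinates to zero mean and unit variance,
# which matches the scale of the sigma=1 priors in the model
scaler = StandardScaler()
scaled_data = scaler.fit_transform(data)

plt.scatter(*scaled_data.T, c=labels)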
And the model:
coords = {"cluster": np.arange(n_clusters),
          "obs_id": np.arange(data.shape[0]),
          "coord": ["x", "y"]}

# When you use the ordered transform, the initial values need to be
# monotonically increasing
sorted_initvals = np.linspace(-2, 2, n_clusters)

with pm.Model(coords=coords) as model:
    # Use alpha > 1 to prevent the model from finding sparse solutions -- we know
    # all 5 clusters should be represented in the posterior
    w = pm.Dirichlet("w", np.full(n_clusters, 10), dims=["cluster"])

    # Mean component: x is ordered to break the label-switching symmetry, y is not
    x_coord = pm.Normal("x_coord", sigma=1, dims=["cluster"],
                        transform=pm.distributions.transforms.ordered,
                        initval=sorted_initvals)
    y_coord = pm.Normal("y_coord", sigma=1, dims=["cluster"])
    centroids = pm.Deterministic("centroids", pt.concatenate([x_coord[None], y_coord[None]]).T,
                                 dims=["cluster", "coord"])

    # Diagonal covariances. Could also model the full covariance matrix, but I didn't try.
    # Note: sigma is placed directly on the diagonal of cov, so it acts as a per-axis variance
    sigma = pm.HalfNormal("sigma", sigma=1, dims=["cluster", "coord"])
    covs = [pt.diag(sigma[i]) for i in range(n_clusters)]

    # Define the mixture: each component is a 2-D MvNormal, so the x and y
    # of a point always share one cluster label
    components = [pm.MvNormal.dist(mu=centroids[i], cov=covs[i]) for i in range(n_clusters)]
    y_hat = pm.Mixture("y_hat",
                       w,
                       components,
                       observed=scaled_data,
                       dims=["obs_id", "coord"])

    idata = pm.sample(init="advi+adapt_diag")
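Why the initial values matter: an ordered transform maps a strictly increasing vector to unconstrained space by keeping the first element and taking the log of each successive gap. The NumPy functions below are a hand-written sketch of that math (they mirror the idea behind PyMC's ordered transform, not its code). A decreasing step gives the log of a negative number, and identical values (for example all zeros) give log(0) = -inf, so neither is a usable starting point.

```python
import numpy as np


def ordered_forward(x):
    """Strictly increasing vector -> unconstrained vector."""
    x = np.asarray(x, dtype=float)
    with np.errstate(divide="ignore", invalid="ignore"):
        return np.concatenate([x[:1], np.log(np.diff(x))])


def ordered_backward(y):
    """Unconstrained vector -> strictly increasing vector."""
    y = np.asarray(y, dtype=float)
    return np.cumsum(np.concatenate([y[:1], np.exp(y[1:])]))


good = np.linspace(-2, 2, 5)                 # same as sorted_initvals
bad = np.array([0.5, -1.0, 0.0, 1.0, 2.0])   # not increasing

print(ordered_forward(good))                 # all finite
print(ordered_forward(bad))                  # contains nan
print(ordered_forward(np.zeros(5)))          # gaps of 0 map to -inf
print(np.allclose(ordered_backward(ordered_forward(good)), good))
```

Because the inverse builds each element as the previous one plus a positive `exp(...)`, any unconstrained value the sampler proposes maps back to an increasing vector, which is what keeps the x means sorted.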
Results:
If you’re interested in the probability of each point, you can use a predictive model to go back and compute that:
from pytensor.tensor.special import softmax

# Flat placeholders: their values are taken from the posterior draws in idata
with pm.Model(coords=coords) as likelihood_model:
    w = pm.Flat("w", dims=["cluster"])
    x_coord = pm.Flat("x_coord", dims=["cluster"])
    y_coord = pm.Flat("y_coord", dims=["cluster"])
    centroids = pm.Deterministic("centroids", pt.concatenate([x_coord[None], y_coord[None]]).T,
                                 dims=["cluster", "coord"])
    sigma = pm.Flat("sigma", dims=["cluster", "coord"])
    covs = [pt.diag(sigma[i]) for i in range(n_clusters)]

    # Unregistered (.dist) components: only their log-densities are needed
    components = [pm.MvNormal.dist(mu=centroids[i], cov=covs[i]) for i in range(n_clusters)]
    # Log-likelihood of every point under every component, shape (obs_id, cluster)
    log_lik = pt.stack([pm.logp(d, scaled_data) for d in components], axis=-1)
    # Add the mixture weights and normalize across clusters to get P(cluster | point)
    p_component = pm.Deterministic("p_component", softmax(pt.log(w) + log_lik, axis=-1),
                                   dims=["obs_id", "cluster"])
    label_hat = pm.Deterministic("label_hat", pt.argmax(p_component, axis=-1), dims=["obs_id"])

    idata_pred = pm.sample_posterior_predictive(idata, var_names=["p_component", "label_hat"])
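For intuition about what the predictive model computes for each posterior draw, here is a NumPy-only sketch of the cluster probabilities (often called responsibilities) for one fixed set of parameters, such as a single draw or the posterior means. The parameter values and data are made up for illustration, and `var` plays the role that `sigma` plays on the diagonal of `cov` in the model.

```python
import numpy as np


def log_norm_diag(x, mu, var):
    """Log density of N(mu, diag(var)) at each row of x. x: (n_obs, n_dim)."""
    return -0.5 * np.sum(np.log(2 * np.pi * var) + (x - mu) ** 2 / var, axis=-1)


def responsibilities(x, w, mu, var):
    """P(cluster k | point n) for a diagonal-covariance Gaussian mixture.

    w: (n_clusters,); mu, var: (n_clusters, n_dim). Returns (n_obs, n_clusters).
    """
    log_lik = np.stack([log_norm_diag(x, mu[k], var[k]) for k in range(len(w))], axis=-1)
    logits = np.log(w) + log_lik
    logits -= logits.max(axis=-1, keepdims=True)   # stabilize before exponentiating
    p = np.exp(logits)
    return p / p.sum(axis=-1, keepdims=True)


rng = np.random.default_rng(1)
mu = np.array([[-2.0, 0.0], [2.0, 0.0]])
var = np.full((2, 2), 0.25)
w = np.array([0.5, 0.5])
true_labels = rng.integers(0, 2, size=200)
x = mu[true_labels] + rng.normal(0.0, 0.5, size=(200, 2))

r = responsibilities(x, w, mu, var)
labels_hat = r.argmax(axis=-1)
print(np.mean(labels_hat == true_labels))   # fraction of points assigned to their true cluster
```

Repeating this over posterior draws and averaging `r` gives probabilities that account for parameter uncertainty, which is roughly what averaging `p_component` over the draws in `idata_pred` does.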