Issues with a multivariate GMM and label switching

I want to create a mixture of multivariate (2D) normals, which I am doing with the `pymc.Mixture` class. I can get the model to sample, and sometimes I get good results, but other times I run into label switching (which seems to happen at random).

To address the label switching, I am trying to apply an ordered transformation to the mixture weights. However, this causes an error at initialization. Below is the code with the transformation:

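import numpy as np
import pymc as pm

# Placeholder data for illustration only: replace with the real (N, 2) observations
rng = np.random.default_rng(0)
sample = np.vstack([rng.normal(-2.0, 1.0, size=(100, 2)), rng.normal(2.0, 1.0, size=(100, 2))])

coords = {"axis": ["x1", "x2"], "components": [1, 2]}
# create the model
with pm.Model(coords=coords) as model:
    # ordered transform on the weights, intended to prevent label switching
    weights = pm.Dirichlet("w", [1, 1], dims="components", transform=pm.distributions.transforms.ordered, initval=[0.2, 0.8])
    components = []
    for i in range(2):
        # Cholesky factor of the 2x2 covariance matrix of component i
        chol, corr, stds = pm.LKJCholeskyCov(
            "chol{}".format(i), n=2, eta=2.0, sd_dist=pm.Exponential.dist(1.0, shape=2)
        )
        mu = pm.Normal("mu{}".format(i), 0.0, sigma=4.0, dims="axis")

        components.append(pm.MvNormal.dist(mu, chol=chol))
    obs = pm.Mixture("obs", w=weights, comp_dists=components, observed=sample)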

When I try to sample the model I get the following error:

SamplingError: Initial evaluation of model at starting point failed!
Starting values:
{'w_ordered__': array([ 1.13887493, -0.77774487]), 'chol0_cholesky-cov-packed__': array([-0.40933986,  0.89455485, -0.64019363]), 'mu0': array([-0.68340835,  0.93668752]), 'chol1_cholesky-cov-packed__': array([ 0.5663125 ,  0.15188474, -0.20971411]), 'mu1': array([-0.16076354, -0.39587404])}

Logp initial evaluation results:
{'w': -inf, 'chol0': -4.54, 'mu0': -4.65, 'chol1': -1.8, 'mu1': -4.62, 'obs': -inf}
You can call `model.debug()` for more details.

For some reason, applying the ordered transform to the weights variable makes the initial log-probabilities of `w` and `obs` equal to -inf, whereas without that transform there is no issue apart from the label switching.

Can someone explain what is wrong?

The `ordered` transform replaces the default simplex transform that Dirichlet variables have, so the weights are no longer constrained to sum to 1. You should use a `Chain` transform that includes both (this will be done by default under the hood in the next release of PyMC).
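To see why the starting point fails, the sketch below re-implements in NumPy the back-transform that an ordered transform uses (x0 = y0, x_i = x_{i-1} + exp(y_i)) and applies it to the unconstrained `w_ordered__` value reported in the error above. The helper name `ordered_backward` is hypothetical and this is not PyMC's own code, just an assumption-labelled illustration: the result is increasing, but it is not a valid probability vector, which is consistent with `w` getting -inf (and presumably `obs` as well, since the mixture also needs weights that sum to 1).

```python
import numpy as np

def ordered_backward(y):
    """Hypothetical helper: map an unconstrained vector to an increasing one.

    x[0] = y[0], x[i] = x[i-1] + exp(y[i]); nothing forces sum(x) == 1.
    """
    y = np.asarray(y, dtype=float)
    x = np.empty_like(y)
    x[0] = y[0]
    x[1:] = y[0] + np.cumsum(np.exp(y[1:]))
    return x

# Unconstrained starting value of w reported in the SamplingError
w_unconstrained = np.array([1.13887493, -0.77774487])
w = ordered_backward(w_unconstrained)

print(w)        # increasing, but the first weight is already > 1
print(w.sum())  # far from 1

# A Dirichlet only has support on the simplex (positive entries summing to 1)
on_simplex = bool(np.all(w > 0) and np.isclose(w.sum(), 1.0))
print(on_simplex)  # False, so the Dirichlet log-density is -inf here
```

Chaining the simplex transform with the ordered one keeps the back-transformed weights on the simplex, which is what removes the -inf.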

This, however, may not help much with the multimodality. Ordering is usually applied to the component parameters instead (e.g., ordered means or sigmas).
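The symmetry behind label switching can be checked numerically. The NumPy sketch below (not PyMC code) shows that permuting the component labels of a 2D Gaussian mixture leaves the likelihood unchanged, so the posterior has one mode per labeling; sorting the components by the first coordinate of their means picks a single labeling, which mirrors the restriction an ordered constraint on the means imposes. The data, parameter values and helper names (`mvn_logpdf`, `mixture_loglik`, `canonical_order`) are hypothetical.

```python
import numpy as np

def mvn_logpdf(x, mu, cov):
    """Log density of a multivariate normal, evaluated row-wise for x of shape (N, d)."""
    d = mu.shape[0]
    diff = x - mu
    chol = np.linalg.cholesky(cov)
    # Solve L z = diff^T, so sum(z**2) is the squared Mahalanobis distance
    z = np.linalg.solve(chol, diff.T)
    log_det = 2.0 * np.sum(np.log(np.diag(chol)))
    return -0.5 * (d * np.log(2 * np.pi) + log_det + np.sum(z**2, axis=0))

def mixture_loglik(x, w, mus, covs):
    """Total log-likelihood of a Gaussian mixture (log-sum-exp over components)."""
    comp = np.stack([np.log(wk) + mvn_logpdf(x, mk, ck) for wk, mk, ck in zip(w, mus, covs)])
    m = comp.max(axis=0)
    return float(np.sum(m + np.log(np.exp(comp - m).sum(axis=0))))

def canonical_order(w, mus, covs):
    """Relabel components so the first coordinate of the means is increasing."""
    order = np.argsort(mus[:, 0])
    return w[order], mus[order], covs[order]

rng = np.random.default_rng(1)
x = rng.normal(size=(50, 2))  # hypothetical 2D data

w = np.array([0.3, 0.7])
mus = np.array([[-1.0, 0.5], [2.0, -0.5]])
covs = np.array([np.eye(2), 0.5 * np.eye(2)])

# Swapping the labels gives exactly the same likelihood: two symmetric modes
perm = [1, 0]
ll = mixture_loglik(x, w, mus, covs)
ll_swapped = mixture_loglik(x, w[perm], mus[perm], covs[perm])
print(np.isclose(ll, ll_swapped))  # True

# Ordering on the first coordinate of the means recovers one canonical labeling
w_c, mus_c, covs_c = canonical_order(w[perm], mus[perm], covs[perm])
print(np.allclose(mus_c, mus))  # True
```

Ordering the weights plays the same role only when the weights are clearly different; ordering a well-separated parameter such as one coordinate of the means tends to identify the components more reliably.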

This indeed solved the issue, and as you said, ordering the weights is not sufficient to break the multimodality :upside_down_face: