Label switching in multivariate mixtures

I need to fit a mixture of Dirichlet distributions, but am having label switching problems in doing so. As an example, I’ve whipped up a toy data set, generated by the following process:

def create_data():
    dir1 = pm.distributions.multivariate.Dirichlet.dist(np.array([1, 5, 2]))
    dir3 = pm.distributions.multivariate.Dirichlet.dist(np.array([7, .5, 1]))
    dir2 = pm.distributions.multivariate.Dirichlet.dist(np.array([2, 3, 3]))
    data = np.concatenate((dir1.random(size=700), dir2.random(size=200), dir3.random(size=100)), axis=0)
    return data

Here is my first, unembellished attempt to fit this data:

def dirichlet(n_dim, suffix=""):
    if not isinstance(suffix, str):
        suffix = str(suffix)
    b = pm.HalfNormal("b" + suffix, sigma=10)
    a = pm.Dirichlet("a" + suffix, np.ones(n_dim))
    c = pm.Deterministic("c" + suffix, a * b)
    return pm.Dirichlet.dist(c, shape=3)

def stick_breaking(beta):
    portion_remaining = tt.concatenate([[1], tt.extra_ops.cumprod(1 - beta)[:-1]])
    return beta * portion_remaining

def summarize_trace(trace, n_models):
    print([trace.at[f'w[i]', 'mean'] for i in range(n_models)])
    for i in range(n_models):
        numbers = [trace.at[f"c_{i}[0]", "mean"], trace.at[f"c_{i}[1]", "mean"], trace.at[f"c_{i}[2]", "mean"]]
        print(f"c_{i}: {numbers}")
    
def estimate_model(data, n_clusters, n_features):
    with pm.Model() as model:
        alpha = pm.Gamma('alpha', 1., 1.)
        beta = pm.Beta('beta', 1, alpha, shape=n_clusters)
        w = pm.Dirichlet('w', stick_breaking(beta), shape=n_clusters)
        b = pm.HalfNormal("b", sigma=10, shape=n_clusters)
        a = pm.Dirichlet("a", np.ones(n_features), shape=[n_clusters, n_features])
        c = [pm.Deterministic(f"c_{k}", a[k] * b[k]) for k in range(n_clusters)]
        obs = pm.Mixture('obs', w, [pm.Dirichlet.dist(c[k], shape=3) for k in range(n_clusters)], observed=data)
        trace = pm.sample(10000, tune=2000, random_seed=123)
        pm.traceplot(trace, ["w", "c_0", "c_1", "c_2"])
        return pm.summary(trace)
N = 10

data = create_data()
summ = estimate_model(data, N, 3)
print_cs(summ, N)

I get as a result:

Sampling 4 chains for 2_000 tune and 10_000 draw iterations (8_000 + 40_000 draws total) took 69 seconds.
There were 5629 divergences after tuning. Increase `target_accept` or reparameterize.
There were 5671 divergences after tuning. Increase `target_accept` or reparameterize.
There were 5522 divergences after tuning. Increase `target_accept` or reparameterize.
There were 5464 divergences after tuning. Increase `target_accept` or reparameterize.
The number of effective samples is smaller than 25% for some parameters.

w:[0.218, 0.15, 0.145, 0.115, 0.099, 0.076, 0.067, 0.045, 0.042, 0.043]
c_0: [2.188, 3.345, 2.56]
c_1: [2.574, 2.96, 2.777]
c_2: [2.442, 3.162, 2.618]
c_3: [2.604, 3.071, 2.624]
c_4: [2.494, 2.885, 2.638]
c_5: [2.711, 2.79, 2.71]
c_6: [2.586, 2.941, 2.559]
c_7: [2.636, 2.742, 2.626]
c_8: [2.566, 2.807, 2.605]
c_9: [2.641, 2.734, 2.642]

(For all of my following attempts I omit a traceplot because I’m image-restricted and they all look similar)

This clearly doesn’t reflect the data and looks like some label switching.

I’d heard that the usual solution to this problem is ordering the paramaeters. While the main object of the model, the c’s, are multidimensional and thus unordered, I tried ordering the b’s as follows:

def estimate_model_order_b(data, n_clusters, n_features):
    with pm.Model() as model:
        alpha = pm.Gamma('alpha', 1., 1.)
        beta = pm.Beta('beta', 1, alpha, shape=n_clusters)
        w = pm.Dirichlet('w', stick_breaking(beta), shape=n_clusters)
        b = pm.HalfNormal("b", sigma=10, shape=n_clusters, transform=tr.ordered, testval=np.linspace(1., 2., n_clusters))
        a = pm.Dirichlet("a", np.ones(n_features), shape=[n_clusters, n_features])
        c = [pm.Deterministic(f"c_{k}", a[k] * b[k]) for k in range(n_clusters)]
        obs = pm.Mixture('obs', w, [pm.Dirichlet.dist(c[k], shape=3) for k in range(n_clusters)], observed=data)
        trace = pm.sample(10000, tune=2000, random_seed=123)
        pm.traceplot(trace, ["w", "c_0", "c_1", "c_2"])
        return pm.summary(trace)

This did not help and I got:

Sampling 4 chains for 2_000 tune and 10_000 draw iterations (8_000 + 40_000 draws total) took 362 seconds.
There were 9234 divergences after tuning. Increase `target_accept` or reparameterize.
The acceptance probability does not match the target. It is 0.6215671333378464, but should be close to 0.8. Try to increase the number of tuning steps.
There were 8961 divergences after tuning. Increase `target_accept` or reparameterize.
The acceptance probability does not match the target. It is 0.5455475257576264, but should be close to 0.8. Try to increase the number of tuning steps.
There were 8903 divergences after tuning. Increase `target_accept` or reparameterize.
The acceptance probability does not match the target. It is 0.5522202860081645, but should be close to 0.8. Try to increase the number of tuning steps.
There were 8543 divergences after tuning. Increase `target_accept` or reparameterize.
The acceptance probability does not match the target. It is 0.6179434162247062, but should be close to 0.8. Try to increase the number of tuning steps.
The rhat statistic is larger than 1.4 for some parameters. The sampler did not converge.
The estimated number of effective samples is smaller than 200 for some parameters.
w: [0.421, 0.261, 0.224, 0.052, 0.012, 0.01, 0.006, 0.006, 0.004, 0.003]
c_0: [1.063, 2.746, 1.514]
c_1: [4.282, 1.787, 1.441]
c_2: [2.683, 3.16, 2.901]
c_3: [5.678, 1.867, 2.288]
c_4: [3.817, 3.53, 3.465]
c_5: [4.194, 3.651, 4.156]
c_6: [4.834, 4.157, 4.486]
c_7: [5.502, 4.863, 4.876]
c_8: [5.62, 5.585, 6.285]
c_9: [7.152, 7.379, 7.004]

I assume that this failed because the a’s were still allowed to labelswitch, which lead to the c’s doing the same. I then thought that it might be possible to simply order the c’s by their first dimension. I tried to use the potential function to do this like so:

def ordering_potential(c):
    thing = 0
    for i in range(len(c) - 1):
        thing += tt.switch(c[i + 1][0] - c[i][0] < 0, -np.inf, 0)
    return thing

def estimate_model(data, n_clusters, n_features):
    with pm.Model() as model:
        alpha = pm.Gamma('alpha', 1., 1.)
        beta = pm.Beta('beta', 1, alpha, shape=n_clusters)
        w = pm.Dirichlet('w', stick_breaking(beta), shape=n_clusters)
        b = pm.HalfNormal("b", sigma=10, shape=n_clusters)
        a = pm.Dirichlet("a", np.ones(n_features), shape=[n_clusters, n_features])
        c = [pm.Deterministic(f"c_{k}", a[k] * b[k]) for k in range(n_clusters)]
        pot = pm.Potential("c_potential", ordering_potential(c))
        obs = pm.Mixture('obs', w, [pm.Dirichlet.dist(c[k], shape=3) for k in range(n_clusters)], observed=data)
        trace = pm.sample(10000, tune=2000, random_seed=123)
        pm.traceplot(trace, ["w", "c_0", "c_1", "c_2"])
        return pm.summary(trace)

This did not work, leaving me with the following error:

/opt/conda/lib/python3.7/site-packages/numpy/core/fromnumeric.py:3118: RuntimeWarning: Mean of empty slice.
  out=out, **kwargs)
/opt/conda/lib/python3.7/site-packages/numpy/core/fromnumeric.py:3118: RuntimeWarning: Mean of empty slice.
  out=out, **kwargs)
/opt/conda/lib/python3.7/site-packages/numpy/core/fromnumeric.py:3118: RuntimeWarning: Mean of empty slice.
  out=out, **kwargs)
/opt/conda/lib/python3.7/site-packages/numpy/core/fromnumeric.py:3118: RuntimeWarning: Mean of empty slice.
  out=out, **kwargs)
Bad initial energy, check any log probabilities that are inf or -inf, nan or very small:
Series([], )
---------------------------------------------------------------------------
RemoteTraceback                           Traceback (most recent call last)
RemoteTraceback: 
"""
Traceback (most recent call last):
  File "/opt/conda/lib/python3.7/site-packages/pymc3/parallel_sampling.py", line 182, in _start_loop
    point, stats = self._compute_point()
  File "/opt/conda/lib/python3.7/site-packages/pymc3/parallel_sampling.py", line 209, in _compute_point
    point, stats = self._step_method.step(self._point)
  File "/opt/conda/lib/python3.7/site-packages/pymc3/step_methods/arraystep.py", line 263, in step
    apoint, stats = self.astep(array)
  File "/opt/conda/lib/python3.7/site-packages/pymc3/step_methods/hmc/base_hmc.py", line 158, in astep
    raise SamplingError("Bad initial energy")
pymc3.exceptions.SamplingError: Bad initial energy
"""

The above exception was the direct cause of the following exception:

SamplingError                             Traceback (most recent call last)
SamplingError: Bad initial energy

The above exception was the direct cause of the following exception:

ParallelSamplingError                     Traceback (most recent call last)
<ipython-input-2-b6fc0d0508dc> in <module>
     96 
     97 data = create_data()
---> 98 summ = estimate_model_ordering_potential(data, N, 3)
     99 summarize_trace(summ, N)

<ipython-input-2-b6fc0d0508dc> in estimate_model_ordering_potential(data, n_clusters, n_features)
     64         pot = pm.Potential("c_potential", ordering_potential(c))
     65         obs = pm.Mixture('obs', w, [pm.Dirichlet.dist(c[k], shape=3) for k in range(n_clusters)], observed=data)
---> 66         trace = pm.sample(10000, tune=2000, random_seed=123)
     67         pm.traceplot(trace, ["w", "c_0", "c_1", "c_2"])
     68         return pm.summary(trace)

/opt/conda/lib/python3.7/site-packages/pymc3/sampling.py in sample(draws, step, init, n_init, start, trace, chain_idx, chains, cores, tune, progressbar, model, random_seed, discard_tuned_samples, compute_convergence_checks, callback, return_inferencedata, idata_kwargs, **kwargs)
    520         _print_step_hierarchy(step)
    521         try:
--> 522             trace = _mp_sample(**sample_args)
    523         except pickle.PickleError:
    524             _log.warning("Could not pickle model, sampling singlethreaded.")

/opt/conda/lib/python3.7/site-packages/pymc3/sampling.py in _mp_sample(draws, tune, step, chains, cores, chain, random_seed, start, progressbar, trace, model, callback, **kwargs)
   1415         try:
   1416             with sampler:
-> 1417                 for draw in sampler:
   1418                     trace = traces[draw.chain - chain]
   1419                     if trace.supports_sampler_stats and draw.stats is not None:

/opt/conda/lib/python3.7/site-packages/pymc3/parallel_sampling.py in __iter__(self)
    410 
    411         while self._active:
--> 412             draw = ProcessAdapter.recv_draw(self._active)
    413             proc, is_last, draw, tuning, stats, warns = draw
    414             self._total_draws += 1

/opt/conda/lib/python3.7/site-packages/pymc3/parallel_sampling.py in recv_draw(processes, timeout)
    312             else:
    313                 error = RuntimeError("Chain %s failed." % proc.chain)
--> 314             raise error from old_error
    315         elif msg[0] == "writing_done":
    316             proc._readable = True

ParallelSamplingError: Bad initial energy

The model seems to have difficulty maintaining 10 models that are ordered in this way. If I crank it down to 3 models (the minimum allowable by the data generating process), the model can finish fitting without exceptions, but it produces results that don’t match the data in similar ways to the previous attempts.

How is one supposed to deal with label switching for multivariate mixture models? I haven’t been able to find a straightforward answer to this question and my own clumsy experimentation has failed. Is this just something that pymc doesn’t yet support?