# Label switching in multivariate mixtures

I need to fit a mixture of Dirichlet distributions, but am having label switching problems in doing so. As an example, I’ve whipped up a toy data set, generated by the following process:

``````def create_data():
dir1 = pm.distributions.multivariate.Dirichlet.dist(np.array([1, 5, 2]))
dir3 = pm.distributions.multivariate.Dirichlet.dist(np.array([7, .5, 1]))
dir2 = pm.distributions.multivariate.Dirichlet.dist(np.array([2, 3, 3]))
data = np.concatenate((dir1.random(size=700), dir2.random(size=200), dir3.random(size=100)), axis=0)
return data
``````

Here is my first, unembellished attempt to fit this data:

``````def dirichlet(n_dim, suffix=""):
if not isinstance(suffix, str):
suffix = str(suffix)
b = pm.HalfNormal("b" + suffix, sigma=10)
a = pm.Dirichlet("a" + suffix, np.ones(n_dim))
c = pm.Deterministic("c" + suffix, a * b)
return pm.Dirichlet.dist(c, shape=3)

def stick_breaking(beta):
    """Turn stick-breaking fractions into weights.

    Element ``k`` of the result is ``beta[k]`` times the product of
    ``1 - beta[j]`` for all ``j < k``.

    Args:
        beta: Vector of break fractions (a Theano tensor).

    Returns:
        Tensor with the same length as ``beta``.
    """
    # Nothing has been broken off before the first piece, hence the leading 1.
    portion_remaining = tt.concatenate([[1], tt.extra_ops.cumprod(1 - beta)[:-1]])
    return beta * portion_remaining

def summarize_trace(trace, n_models, n_features=3):
    """Print the posterior means of the weights and of every ``c_k``.

    Args:
        trace: Table read through ``trace.at[label, "mean"]``, with rows
            labelled ``w[i]`` and ``c_{i}[j]`` (presumably the frame
            returned by ``pm.summary``).
        n_models: Number of mixture components to report.
        n_features: Number of entries of each ``c_k`` to report.
    """
    # float() keeps the printed form plain regardless of the scalar type
    # stored in the table.
    weights = [float(trace.at[f"w[{i}]", "mean"]) for i in range(n_models)]
    print(weights)
    for i in range(n_models):
        numbers = [
            float(trace.at[f"c_{i}[{j}]", "mean"]) for j in range(n_features)
        ]
        print(f"c_{i}: {numbers}")

def estimate_model(data, n_clusters, n_features):
    """Fit the unconstrained mixture of Dirichlets and return its summary.

    Each component ``k`` has concentration ``c_k = a[k] * b[k]``. Draws
    10000 samples after 2000 tuning steps with ``random_seed=123`` and
    plots traces for ``w`` and up to the first three ``c_k``.

    Args:
        data: Observations passed to ``pm.Mixture`` as ``observed``.
        n_clusters: Number of mixture components.
        n_features: Number of categories of each Dirichlet component.

    Returns:
        The result of ``pm.summary(trace)``.
    """
    with pm.Model():
        alpha = pm.Gamma('alpha', 1., 1.)
        beta = pm.Beta('beta', 1, alpha, shape=n_clusters)
        # NOTE(review): the stick-breaking weights are used as the
        # concentration of a Dirichlet prior on w, not as the mixture
        # weights themselves -- confirm this is intended.
        w = pm.Dirichlet('w', stick_breaking(beta), shape=n_clusters)
        b = pm.HalfNormal("b", sigma=10, shape=n_clusters)
        a = pm.Dirichlet("a", np.ones(n_features), shape=[n_clusters, n_features])
        c = [pm.Deterministic(f"c_{k}", a[k] * b[k]) for k in range(n_clusters)]
        # Component shape follows n_features rather than a hard-coded 3.
        components = [pm.Dirichlet.dist(c_k, shape=n_features) for c_k in c]
        pm.Mixture('obs', w, components, observed=data)
        trace = pm.sample(10000, tune=2000, random_seed=123)
        # Only plot c_k that exist, so n_clusters < 3 does not fail here.
        shown = ["w"] + [f"c_{k}" for k in range(min(n_clusters, 3))]
        pm.traceplot(trace, shown)
        return pm.summary(trace)
# Number of mixture components to fit.
N = 10

data = create_data()
summ = estimate_model(data, N, 3)
# summarize_trace is the reporting helper defined in this snippet.
summarize_trace(summ, N)
``````

I get as a result:

``````Sampling 4 chains for 2_000 tune and 10_000 draw iterations (8_000 + 40_000 draws total) took 69 seconds.
There were 5629 divergences after tuning. Increase `target_accept` or reparameterize.
There were 5671 divergences after tuning. Increase `target_accept` or reparameterize.
There were 5522 divergences after tuning. Increase `target_accept` or reparameterize.
There were 5464 divergences after tuning. Increase `target_accept` or reparameterize.
The number of effective samples is smaller than 25% for some parameters.

w:[0.218, 0.15, 0.145, 0.115, 0.099, 0.076, 0.067, 0.045, 0.042, 0.043]
c_0: [2.188, 3.345, 2.56]
c_1: [2.574, 2.96, 2.777]
c_2: [2.442, 3.162, 2.618]
c_3: [2.604, 3.071, 2.624]
c_4: [2.494, 2.885, 2.638]
c_5: [2.711, 2.79, 2.71]
c_6: [2.586, 2.941, 2.559]
c_7: [2.636, 2.742, 2.626]
c_8: [2.566, 2.807, 2.605]
c_9: [2.641, 2.734, 2.642]
``````

(For all of my following attempts I omit a traceplot because I’m image-restricted and they all look similar)

This clearly doesn’t reflect the data and looks like some label switching.

I’d heard that the usual solution to this problem is ordering the parameters. While the main object of the model, the c’s, are multidimensional and thus unordered, I tried ordering the b’s as follows:

``````def estimate_model_order_b(data, n_clusters, n_features):
with pm.Model() as model:
alpha = pm.Gamma('alpha', 1., 1.)
beta = pm.Beta('beta', 1, alpha, shape=n_clusters)
w = pm.Dirichlet('w', stick_breaking(beta), shape=n_clusters)
b = pm.HalfNormal("b", sigma=10, shape=n_clusters, transform=tr.ordered, testval=np.linspace(1., 2., n_clusters))
a = pm.Dirichlet("a", np.ones(n_features), shape=[n_clusters, n_features])
c = [pm.Deterministic(f"c_{k}", a[k] * b[k]) for k in range(n_clusters)]
obs = pm.Mixture('obs', w, [pm.Dirichlet.dist(c[k], shape=3) for k in range(n_clusters)], observed=data)
trace = pm.sample(10000, tune=2000, random_seed=123)
pm.traceplot(trace, ["w", "c_0", "c_1", "c_2"])
return pm.summary(trace)
``````

This did not help and I got:

``````Sampling 4 chains for 2_000 tune and 10_000 draw iterations (8_000 + 40_000 draws total) took 362 seconds.
There were 9234 divergences after tuning. Increase `target_accept` or reparameterize.
The acceptance probability does not match the target. It is 0.6215671333378464, but should be close to 0.8. Try to increase the number of tuning steps.
There were 8961 divergences after tuning. Increase `target_accept` or reparameterize.
The acceptance probability does not match the target. It is 0.5455475257576264, but should be close to 0.8. Try to increase the number of tuning steps.
There were 8903 divergences after tuning. Increase `target_accept` or reparameterize.
The acceptance probability does not match the target. It is 0.5522202860081645, but should be close to 0.8. Try to increase the number of tuning steps.
There were 8543 divergences after tuning. Increase `target_accept` or reparameterize.
The acceptance probability does not match the target. It is 0.6179434162247062, but should be close to 0.8. Try to increase the number of tuning steps.
The rhat statistic is larger than 1.4 for some parameters. The sampler did not converge.
The estimated number of effective samples is smaller than 200 for some parameters.
w: [0.421, 0.261, 0.224, 0.052, 0.012, 0.01, 0.006, 0.006, 0.004, 0.003]
c_0: [1.063, 2.746, 1.514]
c_1: [4.282, 1.787, 1.441]
c_2: [2.683, 3.16, 2.901]
c_3: [5.678, 1.867, 2.288]
c_4: [3.817, 3.53, 3.465]
c_5: [4.194, 3.651, 4.156]
c_6: [4.834, 4.157, 4.486]
c_7: [5.502, 4.863, 4.876]
c_8: [5.62, 5.585, 6.285]
c_9: [7.152, 7.379, 7.004]
``````

I assume that this failed because the a’s were still allowed to label-switch, which led to the c’s doing the same. I then thought that it might be possible to simply order the c’s by their first dimension. I tried to use the potential function to do this like so:

``````def ordering_potential(c):
thing = 0
for i in range(len(c) - 1):
thing += tt.switch(c[i + 1] - c[i] < 0, -np.inf, 0)
return thing

def estimate_model(data, n_clusters, n_features):
    """Fit the mixture of Dirichlets with an ordering potential on the c's.

    Same model as the unconstrained attempt, plus a ``pm.Potential`` named
    ``"c_potential"`` built by ``ordering_potential(c)``. Draws 10000
    samples after 2000 tuning steps with ``random_seed=123`` and plots
    traces for ``w`` and up to the first three ``c_k``.

    Note:
        The traceback quoted in the post calls this function
        ``estimate_model_ordering_potential``; the name used here is the
        one from the snippet as posted.

    Args:
        data: Observations passed to ``pm.Mixture`` as ``observed``.
        n_clusters: Number of mixture components.
        n_features: Number of categories of each Dirichlet component.

    Returns:
        The result of ``pm.summary(trace)``.
    """
    with pm.Model():
        alpha = pm.Gamma('alpha', 1., 1.)
        beta = pm.Beta('beta', 1, alpha, shape=n_clusters)
        # NOTE(review): the stick-breaking weights are used as the
        # concentration of a Dirichlet prior on w, not as the mixture
        # weights themselves -- confirm this is intended.
        w = pm.Dirichlet('w', stick_breaking(beta), shape=n_clusters)
        b = pm.HalfNormal("b", sigma=10, shape=n_clusters)
        a = pm.Dirichlet("a", np.ones(n_features), shape=[n_clusters, n_features])
        c = [pm.Deterministic(f"c_{k}", a[k] * b[k]) for k in range(n_clusters)]
        # NOTE(review): the potential is -inf wherever the ordering is
        # violated; presumably the starting point must already satisfy it
        # for sampling to begin -- confirm.
        pm.Potential("c_potential", ordering_potential(c))
        # Component shape follows n_features rather than a hard-coded 3.
        components = [pm.Dirichlet.dist(c_k, shape=n_features) for c_k in c]
        pm.Mixture('obs', w, components, observed=data)
        trace = pm.sample(10000, tune=2000, random_seed=123)
        # Only plot c_k that exist, so n_clusters < 3 does not fail here.
        shown = ["w"] + [f"c_{k}" for k in range(min(n_clusters, 3))]
        pm.traceplot(trace, shown)
        return pm.summary(trace)
``````

This did not work, leaving me with the following error:

``````/opt/conda/lib/python3.7/site-packages/numpy/core/fromnumeric.py:3118: RuntimeWarning: Mean of empty slice.
out=out, **kwargs)
/opt/conda/lib/python3.7/site-packages/numpy/core/fromnumeric.py:3118: RuntimeWarning: Mean of empty slice.
out=out, **kwargs)
/opt/conda/lib/python3.7/site-packages/numpy/core/fromnumeric.py:3118: RuntimeWarning: Mean of empty slice.
out=out, **kwargs)
/opt/conda/lib/python3.7/site-packages/numpy/core/fromnumeric.py:3118: RuntimeWarning: Mean of empty slice.
out=out, **kwargs)
Bad initial energy, check any log probabilities that are inf or -inf, nan or very small:
Series([], )
---------------------------------------------------------------------------
RemoteTraceback                           Traceback (most recent call last)
RemoteTraceback:
"""
Traceback (most recent call last):
File "/opt/conda/lib/python3.7/site-packages/pymc3/parallel_sampling.py", line 182, in _start_loop
point, stats = self._compute_point()
File "/opt/conda/lib/python3.7/site-packages/pymc3/parallel_sampling.py", line 209, in _compute_point
point, stats = self._step_method.step(self._point)
File "/opt/conda/lib/python3.7/site-packages/pymc3/step_methods/arraystep.py", line 263, in step
apoint, stats = self.astep(array)
File "/opt/conda/lib/python3.7/site-packages/pymc3/step_methods/hmc/base_hmc.py", line 158, in astep
"""

The above exception was the direct cause of the following exception:

SamplingError                             Traceback (most recent call last)

The above exception was the direct cause of the following exception:

ParallelSamplingError                     Traceback (most recent call last)
<ipython-input-2-b6fc0d0508dc> in <module>
96
97 data = create_data()
---> 98 summ = estimate_model_ordering_potential(data, N, 3)
99 summarize_trace(summ, N)

<ipython-input-2-b6fc0d0508dc> in estimate_model_ordering_potential(data, n_clusters, n_features)
64         pot = pm.Potential("c_potential", ordering_potential(c))
65         obs = pm.Mixture('obs', w, [pm.Dirichlet.dist(c[k], shape=3) for k in range(n_clusters)], observed=data)
---> 66         trace = pm.sample(10000, tune=2000, random_seed=123)
67         pm.traceplot(trace, ["w", "c_0", "c_1", "c_2"])
68         return pm.summary(trace)

/opt/conda/lib/python3.7/site-packages/pymc3/sampling.py in sample(draws, step, init, n_init, start, trace, chain_idx, chains, cores, tune, progressbar, model, random_seed, discard_tuned_samples, compute_convergence_checks, callback, return_inferencedata, idata_kwargs, **kwargs)
520         _print_step_hierarchy(step)
521         try:
--> 522             trace = _mp_sample(**sample_args)
523         except pickle.PickleError:
524             _log.warning("Could not pickle model, sampling singlethreaded.")

/opt/conda/lib/python3.7/site-packages/pymc3/sampling.py in _mp_sample(draws, tune, step, chains, cores, chain, random_seed, start, progressbar, trace, model, callback, **kwargs)
1415         try:
1416             with sampler:
-> 1417                 for draw in sampler:
1418                     trace = traces[draw.chain - chain]
1419                     if trace.supports_sampler_stats and draw.stats is not None:

/opt/conda/lib/python3.7/site-packages/pymc3/parallel_sampling.py in __iter__(self)
410
411         while self._active:
413             proc, is_last, draw, tuning, stats, warns = draw
414             self._total_draws += 1

/opt/conda/lib/python3.7/site-packages/pymc3/parallel_sampling.py in recv_draw(processes, timeout)
312             else:
313                 error = RuntimeError("Chain %s failed." % proc.chain)
--> 314             raise error from old_error
315         elif msg == "writing_done":