I need to fit a mixture of Dirichlet distributions, but I am running into label-switching problems. As an example, I’ve created a toy data set generated by the following process:
def create_data():
    """Return a toy data set drawn from three Dirichlet distributions.

    Draws, in this order, 700 samples from Dirichlet([1, 5, 2]), 200 from
    Dirichlet([2, 3, 3]) and 100 from Dirichlet([7, .5, 1]), and stacks them
    along axis 0 (rows keep that order).
    """
    # (concentration, number of draws) in the order the samples are drawn.
    components = (
        (np.array([1, 5, 2]), 700),
        (np.array([2, 3, 3]), 200),
        (np.array([7, .5, 1]), 100),
    )
    samples = [
        pm.distributions.multivariate.Dirichlet.dist(concentration).random(size=count)
        for concentration, count in components
    ]
    return np.concatenate(samples, axis=0)
Here is my first, unembellished attempt to fit this data:
def dirichlet(n_dim, suffix=""):
    """Build one Dirichlet mixture component with a learned concentration.

    Registers three named variables (so it is meant to be called inside a
    ``with pm.Model():`` block):

    * ``"b" + suffix`` -- HalfNormal(sigma=10) total concentration,
    * ``"a" + suffix`` -- flat Dirichlet over ``n_dim`` categories,
    * ``"c" + suffix`` -- deterministic ``a * b``.

    Args:
        n_dim: Number of categories of the component.
        suffix: Appended to the variable names; converted with ``str()``.

    Returns:
        An unregistered ``pm.Dirichlet.dist`` with concentration ``c``, usable
        as a ``pm.Mixture`` component.
    """
    suffix = str(suffix)  # str() leaves strings unchanged, so no type check is needed
    b = pm.HalfNormal("b" + suffix, sigma=10)
    a = pm.Dirichlet("a" + suffix, np.ones(n_dim))
    c = pm.Deterministic("c" + suffix, a * b)
    # The component's dimension must follow n_dim (it was hard-coded to 3).
    return pm.Dirichlet.dist(c, shape=n_dim)
def stick_breaking(beta):
    """Convert break fractions ``beta`` into stick-breaking weights.

    Weight k is ``beta[k] * prod_{j<k} (1 - beta[j])``: the fraction beta[k]
    of whatever stick is left after the first k breaks. For a finite
    (truncated) ``beta`` the weights sum to at most 1.
    """
    leftover = tt.extra_ops.cumprod(1 - beta)
    # Shift the running products right by one: nothing is broken off before step 0.
    remaining_before_break = tt.concatenate([[1], leftover[:-1]])
    return beta * remaining_before_break
def summarize_trace(trace, n_models, n_features=3):
    """Print posterior means of the mixture weights and component concentrations.

    Args:
        trace: Despite its name, a summary table such as the one returned by
            ``pm.summary`` (the caller passes the result of ``estimate_model``).
            It is indexed by names like ``"w[0]"`` / ``"c_0[1]"`` and has a
            ``"mean"`` column.
        n_models: Number of mixture components to report.
        n_features: Length of each ``c_k`` vector (defaults to 3, the value
            previously hard-coded).

    Prints one ``w: [...]`` line followed by one ``c_k: [...]`` line per
    component.
    """
    # The row key needs braces around i; the old f'w[i]' always looked up "w[i]".
    # float() makes the list print as plain numbers rather than numpy reprs.
    weights = [float(trace.at[f"w[{i}]", "mean"]) for i in range(n_models)]
    print(f"w: {weights}")
    for i in range(n_models):
        numbers = [float(trace.at[f"c_{i}[{j}]", "mean"]) for j in range(n_features)]
        print(f"c_{i}: {numbers}")
def estimate_model(data, n_clusters, n_features):
    """Fit a truncated stick-breaking mixture of Dirichlet distributions.

    Component k has concentration ``c_k = a_k * b_k``:

    * ``a_k`` is a flat-Dirichlet point with ``n_features`` entries.
    * ``b_k`` is HalfNormal(10).

    The mixture weights ``w`` have a Dirichlet prior whose concentration is the
    stick-breaking transform of ``beta ~ Beta(1, alpha)``,
    ``alpha ~ Gamma(1, 1)``.

    Samples 10000 draws per chain after 2000 tuning steps (random_seed=123),
    draws a trace plot of ``w`` and up to three ``c_k``, and returns
    ``pm.summary`` of the trace.
    """
    with pm.Model():
        alpha = pm.Gamma('alpha', 1., 1.)
        beta = pm.Beta('beta', 1, alpha, shape=n_clusters)
        w = pm.Dirichlet('w', stick_breaking(beta), shape=n_clusters)
        b = pm.HalfNormal("b", sigma=10, shape=n_clusters)
        a = pm.Dirichlet("a", np.ones(n_features), shape=[n_clusters, n_features])
        c = [pm.Deterministic(f"c_{k}", a[k] * b[k]) for k in range(n_clusters)]
        # Component dimension must match n_features (it was hard-coded to 3).
        components = [pm.Dirichlet.dist(c_k, shape=n_features) for c_k in c]
        pm.Mixture('obs', w, components, observed=data)
        trace = pm.sample(10000, tune=2000, random_seed=123)
        # Requesting c_0..c_2 unconditionally fails when n_clusters < 3.
        plot_vars = ["w"] + [f"c_{k}" for k in range(min(3, n_clusters))]
        pm.traceplot(trace, plot_vars)
        return pm.summary(trace)
N = 10  # number of mixture components in the truncated model
data = create_data()
summ = estimate_model(data, N, 3)
# print_cs is not defined anywhere; summarize_trace is this file's summary printer.
summarize_trace(summ, N)
I get as a result:
Sampling 4 chains for 2_000 tune and 10_000 draw iterations (8_000 + 40_000 draws total) took 69 seconds.
There were 5629 divergences after tuning. Increase `target_accept` or reparameterize.
There were 5671 divergences after tuning. Increase `target_accept` or reparameterize.
There were 5522 divergences after tuning. Increase `target_accept` or reparameterize.
There were 5464 divergences after tuning. Increase `target_accept` or reparameterize.
The number of effective samples is smaller than 25% for some parameters.
w:[0.218, 0.15, 0.145, 0.115, 0.099, 0.076, 0.067, 0.045, 0.042, 0.043]
c_0: [2.188, 3.345, 2.56]
c_1: [2.574, 2.96, 2.777]
c_2: [2.442, 3.162, 2.618]
c_3: [2.604, 3.071, 2.624]
c_4: [2.494, 2.885, 2.638]
c_5: [2.711, 2.79, 2.71]
c_6: [2.586, 2.941, 2.559]
c_7: [2.636, 2.742, 2.626]
c_8: [2.566, 2.807, 2.605]
c_9: [2.641, 2.734, 2.642]
(For the following attempts I omit the trace plots, because I’m limited in how many images I can post and they all look similar.)
This clearly doesn’t reflect the data: all ten components end up with nearly the same concentration vector c, which looks like label switching.
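I’d heard that the usual solution to this problem is to order the parameters. The main objects of the model, the c’s, are multidimensional and so have no natural ordering. Instead I tried ordering the b’s, which are each component’s total concentration (c = a · b with a on the simplex). Here is that attempt: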
def estimate_model_order_b(data, n_clusters, n_features):
    """Fit the Dirichlet mixture with the component scales ``b`` constrained increasing.

    Same model as the unconstrained fit (``c_k = a_k * b_k``, stick-breaking
    prior on the weights ``w``). The difference is that ``b`` is sampled
    through an ordered transform, which imposes ``b_0 < b_1 < ...``.

    Samples 10000 draws per chain after 2000 tuning steps (random_seed=123),
    draws a trace plot of ``w`` and up to three ``c_k``, and returns
    ``pm.summary`` of the trace.
    """
    with pm.Model():
        alpha = pm.Gamma('alpha', 1., 1.)
        beta = pm.Beta('beta', 1, alpha, shape=n_clusters)
        w = pm.Dirichlet('w', stick_breaking(beta), shape=n_clusters)
        # tr.ordered alone *replaces* HalfNormal's log transform, so the sampler
        # can propose negative b (logp = -inf -> divergences). Chaining log
        # then ordered keeps b positive AND increasing.
        # (tr is presumably pymc3.distributions.transforms -- confirm import.)
        b = pm.HalfNormal(
            "b",
            sigma=10,
            shape=n_clusters,
            transform=tr.Chain([tr.log, tr.ordered]),
            testval=np.linspace(1., 2., n_clusters),  # positive and strictly increasing
        )
        a = pm.Dirichlet("a", np.ones(n_features), shape=[n_clusters, n_features])
        c = [pm.Deterministic(f"c_{k}", a[k] * b[k]) for k in range(n_clusters)]
        # Component dimension must match n_features (it was hard-coded to 3).
        components = [pm.Dirichlet.dist(c_k, shape=n_features) for c_k in c]
        pm.Mixture('obs', w, components, observed=data)
        trace = pm.sample(10000, tune=2000, random_seed=123)
        # Requesting c_0..c_2 unconditionally fails when n_clusters < 3.
        plot_vars = ["w"] + [f"c_{k}" for k in range(min(3, n_clusters))]
        pm.traceplot(trace, plot_vars)
        return pm.summary(trace)
This did not help and I got:
Sampling 4 chains for 2_000 tune and 10_000 draw iterations (8_000 + 40_000 draws total) took 362 seconds.
There were 9234 divergences after tuning. Increase `target_accept` or reparameterize.
The acceptance probability does not match the target. It is 0.6215671333378464, but should be close to 0.8. Try to increase the number of tuning steps.
There were 8961 divergences after tuning. Increase `target_accept` or reparameterize.
The acceptance probability does not match the target. It is 0.5455475257576264, but should be close to 0.8. Try to increase the number of tuning steps.
There were 8903 divergences after tuning. Increase `target_accept` or reparameterize.
The acceptance probability does not match the target. It is 0.5522202860081645, but should be close to 0.8. Try to increase the number of tuning steps.
There were 8543 divergences after tuning. Increase `target_accept` or reparameterize.
The acceptance probability does not match the target. It is 0.6179434162247062, but should be close to 0.8. Try to increase the number of tuning steps.
The rhat statistic is larger than 1.4 for some parameters. The sampler did not converge.
The estimated number of effective samples is smaller than 200 for some parameters.
w: [0.421, 0.261, 0.224, 0.052, 0.012, 0.01, 0.006, 0.006, 0.004, 0.003]
c_0: [1.063, 2.746, 1.514]
c_1: [4.282, 1.787, 1.441]
c_2: [2.683, 3.16, 2.901]
c_3: [5.678, 1.867, 2.288]
c_4: [3.817, 3.53, 3.465]
c_5: [4.194, 3.651, 4.156]
c_6: [4.834, 4.157, 4.486]
c_7: [5.502, 4.863, 4.876]
c_8: [5.62, 5.585, 6.285]
c_9: [7.152, 7.379, 7.004]
I assume this failed because the a’s were still free to label-switch, which led the c’s to do the same. I then thought it might be possible to simply order the c’s by their first dimension. I tried to enforce this with a `pm.Potential` like so:
def ordering_potential(c):
    """Return a log-potential that forbids decreasing first coordinates.

    For consecutive entries of ``c``, adds ``-inf`` whenever
    ``c[k][0] < c[k - 1][0]`` and 0 otherwise; ties are allowed.

    This is a hard wall: the log-density is discontinuous and its gradient is
    zero everywhere it is finite. A start point that violates the ordering
    therefore has -inf log-probability, and the gradient gives NUTS no signal
    toward the boundary.
    """
    total = 0
    for k in range(1, len(c)):
        step = c[k][0] - c[k - 1][0]
        total += tt.switch(step < 0, -np.inf, 0)
    return total
def estimate_model(data, n_clusters, n_features):
    """Fit the Dirichlet mixture with components ordered by the first coordinate of c_k.

    Same model as the unconstrained fit, plus a ``"c_potential"``
    ``pm.Potential`` built by ``ordering_potential(c)``.

    NOTE(review): this definition reuses the name ``estimate_model``, so it
    shadows the earlier definition of that name from here on. The traceback
    in this file calls this function ``estimate_model_ordering_potential``;
    consider restoring that name.

    Samples 10000 draws per chain after 2000 tuning steps (random_seed=123)
    starting from the ordered test point, draws a trace plot of ``w`` and up
    to three ``c_k``, and returns ``pm.summary`` of the trace.
    """
    with pm.Model():
        alpha = pm.Gamma('alpha', 1., 1.)
        beta = pm.Beta('beta', 1, alpha, shape=n_clusters)
        w = pm.Dirichlet('w', stick_breaking(beta), shape=n_clusters)
        # Strictly increasing starting scales, so c_k[0] = a_k[0] * b_k starts
        # ordered (assumes a's test value is identical across rows -- TODO confirm).
        b = pm.HalfNormal("b", sigma=10, shape=n_clusters, testval=np.linspace(1., 2., n_clusters))
        a = pm.Dirichlet("a", np.ones(n_features), shape=[n_clusters, n_features])
        c = [pm.Deterministic(f"c_{k}", a[k] * b[k]) for k in range(n_clusters)]
        pm.Potential("c_potential", ordering_potential(c))
        # Component dimension must match n_features (it was hard-coded to 3).
        components = [pm.Dirichlet.dist(c_k, shape=n_features) for c_k in c]
        pm.Mixture('obs', w, components, observed=data)
        # The default init ("jitter+adapt_diag") randomly perturbs the start
        # point, which is very likely to break a 10-way ordering and land on
        # the -inf potential ("Bad initial energy"). Start un-jittered instead.
        trace = pm.sample(10000, tune=2000, random_seed=123, init="adapt_diag")
        # Requesting c_0..c_2 unconditionally fails when n_clusters < 3.
        plot_vars = ["w"] + [f"c_{k}" for k in range(min(3, n_clusters))]
        pm.traceplot(trace, plot_vars)
        return pm.summary(trace)
This did not work, leaving me with the following error:
/opt/conda/lib/python3.7/site-packages/numpy/core/fromnumeric.py:3118: RuntimeWarning: Mean of empty slice.
out=out, **kwargs)
/opt/conda/lib/python3.7/site-packages/numpy/core/fromnumeric.py:3118: RuntimeWarning: Mean of empty slice.
out=out, **kwargs)
/opt/conda/lib/python3.7/site-packages/numpy/core/fromnumeric.py:3118: RuntimeWarning: Mean of empty slice.
out=out, **kwargs)
/opt/conda/lib/python3.7/site-packages/numpy/core/fromnumeric.py:3118: RuntimeWarning: Mean of empty slice.
out=out, **kwargs)
Bad initial energy, check any log probabilities that are inf or -inf, nan or very small:
Series([], )
---------------------------------------------------------------------------
RemoteTraceback Traceback (most recent call last)
RemoteTraceback:
"""
Traceback (most recent call last):
File "/opt/conda/lib/python3.7/site-packages/pymc3/parallel_sampling.py", line 182, in _start_loop
point, stats = self._compute_point()
File "/opt/conda/lib/python3.7/site-packages/pymc3/parallel_sampling.py", line 209, in _compute_point
point, stats = self._step_method.step(self._point)
File "/opt/conda/lib/python3.7/site-packages/pymc3/step_methods/arraystep.py", line 263, in step
apoint, stats = self.astep(array)
File "/opt/conda/lib/python3.7/site-packages/pymc3/step_methods/hmc/base_hmc.py", line 158, in astep
raise SamplingError("Bad initial energy")
pymc3.exceptions.SamplingError: Bad initial energy
"""
The above exception was the direct cause of the following exception:
SamplingError Traceback (most recent call last)
SamplingError: Bad initial energy
The above exception was the direct cause of the following exception:
ParallelSamplingError Traceback (most recent call last)
<ipython-input-2-b6fc0d0508dc> in <module>
96
97 data = create_data()
---> 98 summ = estimate_model_ordering_potential(data, N, 3)
99 summarize_trace(summ, N)
<ipython-input-2-b6fc0d0508dc> in estimate_model_ordering_potential(data, n_clusters, n_features)
64 pot = pm.Potential("c_potential", ordering_potential(c))
65 obs = pm.Mixture('obs', w, [pm.Dirichlet.dist(c[k], shape=3) for k in range(n_clusters)], observed=data)
---> 66 trace = pm.sample(10000, tune=2000, random_seed=123)
67 pm.traceplot(trace, ["w", "c_0", "c_1", "c_2"])
68 return pm.summary(trace)
/opt/conda/lib/python3.7/site-packages/pymc3/sampling.py in sample(draws, step, init, n_init, start, trace, chain_idx, chains, cores, tune, progressbar, model, random_seed, discard_tuned_samples, compute_convergence_checks, callback, return_inferencedata, idata_kwargs, **kwargs)
520 _print_step_hierarchy(step)
521 try:
--> 522 trace = _mp_sample(**sample_args)
523 except pickle.PickleError:
524 _log.warning("Could not pickle model, sampling singlethreaded.")
/opt/conda/lib/python3.7/site-packages/pymc3/sampling.py in _mp_sample(draws, tune, step, chains, cores, chain, random_seed, start, progressbar, trace, model, callback, **kwargs)
1415 try:
1416 with sampler:
-> 1417 for draw in sampler:
1418 trace = traces[draw.chain - chain]
1419 if trace.supports_sampler_stats and draw.stats is not None:
/opt/conda/lib/python3.7/site-packages/pymc3/parallel_sampling.py in __iter__(self)
410
411 while self._active:
--> 412 draw = ProcessAdapter.recv_draw(self._active)
413 proc, is_last, draw, tuning, stats, warns = draw
414 self._total_draws += 1
/opt/conda/lib/python3.7/site-packages/pymc3/parallel_sampling.py in recv_draw(processes, timeout)
312 else:
313 error = RuntimeError("Chain %s failed." % proc.chain)
--> 314 raise error from old_error
315 elif msg[0] == "writing_done":
316 proc._readable = True
ParallelSamplingError: Bad initial energy
The model seems to have difficulty keeping 10 components ordered in this way. If I reduce it to 3 components (the number used by the data-generating process), the model finishes fitting without exceptions. However, the results fail to match the data in the same way as in the previous attempts.
How is one supposed to deal with label switching in multivariate mixture models? I haven’t been able to find a straightforward answer to this question, and my own clumsy experimentation has failed. Is this just something PyMC3 doesn’t support yet?