Difficulties fitting a mixture of Dirichlet distributions

I’m trying to fit a mixture of Dirichlet distributions but am running into difficulties. I don’t get any errors, but the fits I get are obviously incorrect. Here is my code:

%pylab inline
import numpy as np
import pymc3 as pm
import theano.tensor as tt
import pandas as pd
import random
import math


def create_data():
    dir1 = pm.distributions.multivariate.Dirichlet.dist(np.array([1, 5, 2]))
    dir2 = pm.distributions.multivariate.Dirichlet.dist(np.array([7, .5, 1]))
    dir3 = pm.distributions.multivariate.Dirichlet.dist(np.array([2, 3, 3]))
    # Draw 700/200/100 points from the three components (mixture weights .7/.2/.1)
    data = np.concatenate((dir1.random(size=700), dir2.random(size=200), dir3.random(size=100)), axis=0)
    return data

def dirichlet(n_dim, suffix=""):
    if not isinstance(suffix, str):
        suffix = str(suffix)
    b = pm.HalfNormal("b" + suffix, sigma=10)
    a = pm.Dirichlet("a" + suffix, np.ones(n_dim))
    c = pm.Deterministic("c" + suffix, a * b)
    return pm.Dirichlet.dist(c, shape=n_dim)

def stick_breaking(beta):
    portion_remaining = tt.concatenate([[1], tt.extra_ops.cumprod(1 - beta)[:-1]])
    return beta * portion_remaining
    
def estimate_model(data, n_clusters, n_features):
    with pm.Model() as model:
        alpha = pm.Gamma('alpha', 1., 1.)
        beta = pm.Beta('beta', 1, alpha, shape=n_clusters)
        w = pm.Dirichlet('w', stick_breaking(beta), shape=n_clusters)
        obs = pm.Mixture('obs', w, [dirichlet(3, k) for k in range(n_clusters)], observed=data)

        trace = pm.sample(50000, tune=10000)
        pm.traceplot(trace, ["w", "a0", "b0", "c0", "a1", "b1", "c1"])
        return pm.summary(trace)

data = create_data()
summ = estimate_model(data, 10, 3)
print(summ)
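
For reference, here is a minimal NumPy sketch of what the `stick_breaking` helper computes, so the values can be inspected outside the model. The name `stick_breaking_np`, the seed, and fixing `alpha` at 1.0 are assumptions made only for this illustration. With a finite truncation the weights are positive and sum to slightly less than 1. Note that in the model above the result is passed to `pm.Dirichlet` as its concentration parameter rather than used directly as the mixture weights; it may be worth checking whether that is intended.

```python
import numpy as np

def stick_breaking_np(beta):
    """NumPy analogue of the Theano stick_breaking helper in the post."""
    beta = np.asarray(beta, dtype=float)
    # Fraction of the stick left before each break: 1, (1-b0), (1-b0)(1-b1), ...
    portion_remaining = np.concatenate([[1.0], np.cumprod(1.0 - beta)[:-1]])
    return beta * portion_remaining

rng = np.random.default_rng(0)       # arbitrary seed, for illustration only
beta = rng.beta(1.0, 1.0, size=10)   # assumes alpha = 1.0 instead of sampling it
w = stick_breaking_np(beta)
leftover = 1.0 - w.sum()             # mass the truncated process leaves unassigned
print(w.round(3), leftover)
```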

To match the data generating process, I should end up with 3 w’s with significant weight, at about .7, .2, and .1, and the c’s should reflect the vectors in the create_data method. I never get anything close to this result.
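
To make the target concrete, the sketch below works out what `a`, `b`, and `c` would be for each component if the fit recovered the generating process. It uses plain NumPy rather than PyMC3, and the variable names and seed are assumptions for illustration. Since the model sets `c = a * b` with `a` on the simplex, the true values are `b = sum(c)` and `a = c / b`.

```python
import numpy as np

# Concentration vectors and mixture weights as described in the post.
alphas = np.array([[1.0, 5.0, 2.0],
                   [7.0, 0.5, 1.0],
                   [2.0, 3.0, 3.0]])
sizes = [700, 200, 100]              # 1000 points with weights .7, .2, .1

# c = a * b with a summing to 1, so b is the total and a the normalized vector.
b_true = alphas.sum(axis=1)
a_true = alphas / b_true[:, None]

rng = np.random.default_rng(0)       # arbitrary seed, for illustration only
groups = [rng.dirichlet(alpha, size=n) for alpha, n in zip(alphas, sizes)]
data = np.concatenate(groups, axis=0)

print("b:", b_true)
print("a:", a_true.round(3))
# The per-component sample means should be close to the rows of a_true.
print("empirical means:", np.array([g.mean(axis=0) for g in groups]).round(3))
```

Two of the three components have the same total concentration `b`, so `b` alone does not distinguish them.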

Any suggestions?

If it helps, here are the warnings I get when running:

Sampling 4 chains for 2_000 tune and 10_000 draw iterations (8_000 + 40_000 draws total) took 250 seconds.
There were 7001 divergences after tuning. Increase `target_accept` or reparameterize.
The acceptance probability does not match the target. It is 0.7014387282091336, but should be close to 0.8. Try to increase the number of tuning steps.
There were 7056 divergences after tuning. Increase `target_accept` or reparameterize.
There were 6769 divergences after tuning. Increase `target_accept` or reparameterize.
There were 6957 divergences after tuning. Increase `target_accept` or reparameterize.
The acceptance probability does not match the target. It is 0.6761237144949118, but should be close to 0.8. Try to increase the number of tuning steps.
The rhat statistic is larger than 1.4 for some parameters. The sampler did not converge.
The estimated number of effective samples is smaller than 200 for some parameters.

My traceplot looks like this: (image not included)

I have looked through previous Discourse threads for answers, but have been unable to find any that address my situation. Is there any additional information I need to provide?

Hi @72nd,
Dirichlet mixtures are quite hard to fit. I’m not well versed in them, but I’m guessing you have a “label-switching” problem: your c’s seem to be multimodal. A usual fix is to constrain the c’s to be ordered, IIRC.
You can find a good introduction to these models in @aloctavodia’s book, Bayesian Analysis with Python.
Hope this helps :vulcan_salute:
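
To illustrate what label switching means here, the following sketch (plain NumPy, not the PyMC3 model from the question) checks that a Dirichlet mixture's likelihood is unchanged when the component labels are permuted. The helper names, the seed, and the toy data are assumptions for illustration. Because every relabeling gives the same likelihood, the posterior has one mode per permutation unless something breaks the symmetry.

```python
import math
import numpy as np

def dirichlet_logpdf(x, alpha):
    """Log density of Dirichlet(alpha) evaluated at each row of x."""
    alpha = np.asarray(alpha, dtype=float)
    log_norm = math.lgamma(alpha.sum()) - sum(math.lgamma(a) for a in alpha)
    return log_norm + ((alpha - 1.0) * np.log(x)).sum(axis=1)

def mixture_loglik(x, weights, alphas):
    """Total log-likelihood of x under a finite Dirichlet mixture."""
    comp = np.stack([np.log(w) + dirichlet_logpdf(x, a)
                     for w, a in zip(weights, alphas)])
    m = comp.max(axis=0)  # log-sum-exp over components for numerical stability
    return float((m + np.log(np.exp(comp - m).sum(axis=0))).sum())

rng = np.random.default_rng(0)       # arbitrary seed, for illustration only
x = rng.dirichlet([2.0, 3.0, 3.0], size=50)
weights = np.array([0.7, 0.2, 0.1])
alphas = np.array([[1.0, 5.0, 2.0], [7.0, 0.5, 1.0], [2.0, 3.0, 3.0]])

perm = [2, 0, 1]                     # relabel the three components
ll_original = mixture_loglik(x, weights, alphas)
ll_permuted = mixture_loglik(x, weights[perm], alphas[perm])
print(ll_original, ll_permuted)      # equal up to floating-point rounding
```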

@AlexAndorra, thank you so much for replying.
I’d run across the ordering solution in the past, but I don’t think I can apply it to multivariate data because there isn’t an ordering on tuples. Per your suggestion, I tried ordering the b’s, since they can be ordered. It did not meaningfully change the fit behavior, probably because the a’s kept label switching. It strikes me as plausible that ordering the a’s by their first component might fix this problem, but I don’t know how to do that. Is there some way to use an ordering transform on multidimensional data that I’m not aware of?
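
One possible way to act on the "order by first component" idea is to relabel the draws after sampling. This is post-hoc relabeling, not an in-model ordering transform, and it does nothing about divergences. The sketch assumes the per-cluster `a` draws have already been stacked into one array of shape `(n_draws, n_clusters, n_dim)` with a matching weights array; the function name and the toy arrays are hypothetical.

```python
import numpy as np

def relabel_by_first_component(a_draws, w_draws):
    """Sort the clusters within each draw by the first coordinate of a.

    a_draws: array of shape (n_draws, n_clusters, n_dim)
    w_draws: array of shape (n_draws, n_clusters)
    """
    order = np.argsort(a_draws[:, :, 0], axis=1)   # (n_draws, n_clusters)
    a_sorted = np.take_along_axis(a_draws, order[:, :, None], axis=1)
    w_sorted = np.take_along_axis(w_draws, order, axis=1)
    return a_sorted, w_sorted

# Toy example: two draws describing the same two clusters with labels swapped.
a_draws = np.array([
    [[0.125, 0.625, 0.25], [0.82, 0.06, 0.12]],
    [[0.82, 0.06, 0.12], [0.125, 0.625, 0.25]],
])
w_draws = np.array([[0.7, 0.3], [0.3, 0.7]])

a_sorted, w_sorted = relabel_by_first_component(a_draws, w_draws)
print(a_sorted)
print(w_sorted)
```

Sorting on a single coordinate only separates the clusters if their first components are well separated in the posterior; otherwise draws near the boundary can still be mislabeled.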

I’m sorry, I don’t know much about this type of model yet :confused:
Maybe other people here will be able to help you. I also hope you’ll be able to find something useful in Osvaldo’s book.
Good luck with this, and sorry I couldn’t help more!