Dirichlet Mixture model

Hi, I am playing with a Dirichlet Mixture model. I am trying to cluster the Iris dataset.

In one model I generate the weights pi from a Dirichlet process (via stick-breaking) and use them to find the number of clusters.
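For reference, here is the truncated stick-breaking construction that the linked tutorial uses. The `dirichlet_process` helper in the code below is assumed to implement it, since its body is not shown. The second line gives the expected weight for a fixed alpha, which follows from the independence of the sticks and E[beta_j] = 1/(1+alpha):

$$
\beta_j \sim \mathrm{Beta}(1, \alpha), \qquad
\pi_1 = \beta_1, \qquad
\pi_j = \beta_j \prod_{i=1}^{j-1} (1 - \beta_i), \quad j = 2, \dots, k,
$$

$$
\mathbb{E}[\pi_j \mid \alpha] = \frac{1}{1+\alpha}\left(\frac{\alpha}{1+\alpha}\right)^{j-1}.
$$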

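```python
import pymc3 as pm
import theano.tensor as tt

k = 30  # truncation level: maximum number of mixture components

def dirichlet_process(beta):
    # Assumed helper (not shown in the post): truncated stick-breaking,
    # following the stick_breaking function in the linked tutorial
    portion_remaining = tt.concatenate([[1], tt.extra_ops.cumprod(1 - beta)[:-1]])
    return beta * portion_remaining

m = pm.Model()
with m:
    alpha = pm.HalfNormal('alpha', 1, shape=1)
    beta = pm.Beta('beta', 1, alpha, shape=k)
    pi = pm.Deterministic('pi', dirichlet_process(beta))

    sd = pm.HalfNormal('sd', sd=1, shape=k)
    mu = pm.Normal('mu', mu=0, sd=sd, shape=k)
    obs = pm.NormalMixture('obs', w=pi, mu=mu, sd=sd, observed=df.values[3])
```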

This works well.

I am trying to reproduce the same good results with the pm.Dirichlet distribution (which, as I understand it, also uses stick-breaking).

However, when I try the following, the results look nothing like those of the first model:

```python
import pymc3 as pm
import theano.tensor as tt

k = 30
m = pm.Model()
with m:
    alpha = pm.HalfNormal('alpha', 1, shape=1)
    # Symmetric Dirichlet prior on the weights: every concentration equals alpha
    pi = pm.Dirichlet('pi', tt.ones(k) * alpha, shape=k)

    sd = pm.HalfNormal('sd', sd=1, shape=k)
    mu = pm.Normal('mu', mu=0, sd=sd, shape=k)
    obs = pm.NormalMixture('obs', w=pi, mu=mu, sd=sd, observed=df.values[3])
```

Am I using the pm.Dirichlet distribution wrong?

N.B. I am trying to do something similar to what is done in this tutorial:
https://docs.pymc.io/notebooks/dp_mix.html

Although the Dirichlet distribution also uses a stick-breaking transformation internally, it is not the same as a Dirichlet process constructed with stick-breaking. The transform only maps the weights to an unconstrained space for sampling; the model logp computed from the two is different. You can also check this by generating samples from both.
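To make the distinction concrete, here is a generic stick-breaking bijection between the simplex and unconstrained space, the kind of transform a sampler uses so it can work with unconstrained variables. This is a minimal NumPy sketch with hypothetical function names; PyMC's actual transform may differ in its details. The point is that the transform is invertible and only reparameterizes the weights. The prior on them is still whatever density is placed on the simplex (a Dirichlet here, with the log-Jacobian added during sampling), not independent Beta(1, alpha) sticks. It also always yields weights that sum exactly to 1, because the last component takes the remainder.

```python
import numpy as np


def simplex_to_unconstrained(x):
    """Map a point on the K-simplex to R^(K-1) by stick-breaking + logit (hypothetical helper)."""
    x = np.asarray(x, dtype=float)
    # Stick length left before each break: 1, 1 - x_1, 1 - x_1 - x_2, ...
    remaining = 1.0 - np.concatenate([[0.0], np.cumsum(x[:-1])])
    # Fraction of the remaining stick taken by each of the first K - 1 pieces
    z = x[:-1] / remaining[:-1]
    return np.log(z) - np.log1p(-z)  # logit(z)


def unconstrained_to_simplex(y):
    """Inverse map from R^(K-1) back to the K-simplex (hypothetical helper)."""
    z = 1.0 / (1.0 + np.exp(-np.asarray(y, dtype=float)))  # sigmoid
    x = np.empty(len(z) + 1)
    remaining = 1.0
    for j, zj in enumerate(z):
        x[j] = remaining * zj
        remaining -= x[j]
    x[-1] = remaining  # last component absorbs the rest, so x sums to 1
    return x


rng = np.random.default_rng(0)
x = rng.dirichlet(np.ones(4))     # a point on the 4-simplex
y = simplex_to_unconstrained(x)   # 3 unconstrained numbers
x_back = unconstrained_to_simplex(y)
print(np.allclose(x, x_back), np.isclose(x_back.sum(), 1.0))  # True True
```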

For example, sampling weights via stick-breaking:

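```python
import numpy as np
import scipy.stats as st

def stick_breaking(beta):  # assumed NumPy helper: breaks sticks along axis 0
    remaining = np.concatenate([np.ones((1, beta.shape[1])), np.cumprod(1 - beta, axis=0)[:-1]])
    return beta * remaining

k = 3
size = 10000
a = st.halfnorm.rvs(loc=0., scale=1, size=size)  # one concentration per draw
b = st.beta.rvs(1, a, size=(k, size))             # stick proportions, shape (k, size)
p = stick_breaking(b).T                            # weights, shape (size, k)
```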

Plotting p:

[Figure: samples of p from stick-breaking]

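And sampling from a symmetric Dirichlet with the same concentrations a:

```python
p = []
for ia in a:
    try:
        p_tmp = st.dirichlet.rvs(np.ones(k) * ia)
    except ZeroDivisionError:
        # A very small concentration can make the draw fail: redraw it and retry once
        ia = st.halfnorm.rvs(loc=0., scale=1)
        p_tmp = st.dirichlet.rvs(np.ones(k) * ia)
    p.append(p_tmp)

p = np.squeeze(np.asarray(p))  # shape (size, k)
```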

[Figure: samples of p from the symmetric Dirichlet]
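One way to see the difference between the two sets of samples without plots is to compare the average weight per component. The sketch below uses NumPy only and assumes a fixed concentration alpha = 1 instead of the HalfNormal draws above, for clarity. Stick-breaking weights decrease on average with the component index, following E[pi_j] = (1/(1+alpha)) (alpha/(1+alpha))^(j-1). A symmetric Dirichlet is exchangeable and gives every component the same mean, 1/k.

```python
import numpy as np

rng = np.random.default_rng(42)
k, size, alpha = 5, 200_000, 1.0  # fixed concentration, chosen for illustration

# Truncated stick-breaking with beta_j ~ Beta(1, alpha); one row per draw
b = rng.beta(1.0, alpha, size=(size, k))
remaining = np.concatenate([np.ones((size, 1)), np.cumprod(1.0 - b, axis=1)[:, :-1]], axis=1)
p_sb = b * remaining

# Symmetric Dirichlet with every concentration equal to alpha
p_dir = rng.dirichlet(np.full(k, alpha), size=size)

# Analytic mean of the stick-breaking weights
j = np.arange(1, k + 1)
expected_sb = (1.0 / (1.0 + alpha)) * (alpha / (1.0 + alpha)) ** (j - 1)

print(np.round(p_sb.mean(axis=0), 3))   # close to expected_sb: decreasing in j
print(np.round(p_dir.mean(axis=0), 3))  # close to 1 / k for every component
```

With alpha = 1, expected_sb is 0.5, 0.25, 0.125, 0.0625, 0.03125, so the first components dominate. The symmetric Dirichlet has no such ordering, which is why the two models behave so differently as mixture weights.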

Check out this notebook for more information: https://github.com/junpenglao/Planet_Sakaar_Data_Science/blob/master/Miscellaneous/Softmax%20normal%20compare%20with%20Dirichlet.ipynb
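One more detail to be aware of: a truncated stick-breaking construction does not guarantee that the weight vector sums to 1, because the leftover stick mass (the product of (1 - beta_j) over all k breaks) is discarded. This is not a problem when k is large, but it could raise an error when k is small.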
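The leftover mass is easy to inspect directly. The hypothetical helper below returns the k weights together with the discarded stick length, prod_j (1 - beta_j). The weights plus the leftover always add up to 1. With beta_j ~ Beta(1, alpha), the expected leftover is (alpha/(1+alpha))^k, so it vanishes quickly as k grows. This is a NumPy sketch, not PyMC code.

```python
import numpy as np


def truncated_stick_breaking(beta):
    """Return the k stick-breaking weights and the leftover stick mass (hypothetical helper)."""
    beta = np.asarray(beta, dtype=float)
    # Stick length before each break: 1, (1 - b1), (1 - b1)(1 - b2), ...
    remaining = np.concatenate([[1.0], np.cumprod(1.0 - beta)])
    weights = beta * remaining[:-1]
    leftover = remaining[-1]  # prod_j (1 - beta_j), never assigned to a component
    return weights, leftover


# Deterministic illustration: every break takes half of what is left
w, leftover = truncated_stick_breaking([0.5, 0.5, 0.5])
print(w, w.sum(), leftover)  # weights 0.5, 0.25, 0.125; sum 0.875; leftover 0.125

# Random sticks with beta_j ~ Beta(1, alpha): the leftover shrinks as k grows
rng = np.random.default_rng(0)
alpha = 1.0
for k in (3, 30):
    leftovers = [truncated_stick_breaking(rng.beta(1.0, alpha, size=k))[1] for _ in range(2000)]
    print(k, np.mean(leftovers))  # mean is close to (alpha / (1 + alpha)) ** k
```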

Ah, I see. Very clear answer and neat notebook. Thanks for the clarification.
