# Dirichlet Mixture model

Hi, I am playing with a Dirichlet Mixture model. I am trying to cluster the Iris dataset.

In one model I generate the weights `pi` from a Dirichlet process (via stick-breaking) and use them to find the number of clusters.

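```python
import pymc3 as pm
import theano.tensor as tt

def dirichlet_process(beta):
    # stick-breaking helper, assumed to match `stick_breaking` from the tutorial
    portion_remaining = tt.concatenate([[1], tt.extra_ops.cumprod(1 - beta)[:-1]])
    return beta * portion_remaining

k = 30

m = pm.Model()
with m:
    alpha = pm.HalfNormal('alpha', 1, shape=1)
    beta = pm.Beta('beta', 1, alpha, shape=k)
    pi = pm.Deterministic('pi', dirichlet_process(beta))

    sd = pm.HalfNormal('sd', sd=1, shape=k)
    mu = pm.Normal('mu', mu=0, sd=sd, shape=k)
    obs = pm.NormalMixture('obs', w=pi, mu=mu, sd=sd, observed=df.values[3])
```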


This works well.

I tried to reproduce the same results with the `pm.Dirichlet` distribution (which also uses a stick-breaking transformation internally).

However, when I try the following, the result is nothing like the first model:

```python
import pymc3 as pm
import theano.tensor as tt

k = 30
m = pm.Model()
with m:
    alpha = pm.HalfNormal('alpha', 1, shape=1)
    pi = pm.Dirichlet('pi', tt.ones(k) * alpha, shape=k)

    sd = pm.HalfNormal('sd', sd=1, shape=k)
    mu = pm.Normal('mu', mu=0, sd=sd, shape=k)
    obs = pm.NormalMixture('obs', w=pi, mu=mu, sd=sd, observed=df.values[3])
```


Am I using the `pm.Dirichlet` distribution incorrectly?

N.B. I am trying to do something similar to what is done in this tutorial:
https://docs.pymc.io/notebooks/dp_mix.html

Although the Dirichlet distribution also uses a stick-breaking transformation internally, it is not the same as a Dirichlet process constructed by stick-breaking. The model logp computed from the two is different; you can also check this by generating samples from both.
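In terms of the Beta draws, the two constructions differ as follows. This is a sketch of the standard results, with 1-based indices over `k` components; the notation is mine rather than the tutorial's:

$$
\begin{aligned}
\text{stick-breaking (first model):}\quad & \beta_i \sim \mathrm{Beta}(1, \alpha), && i = 1, \dots, k \\
\mathrm{Dirichlet}(\alpha_1, \dots, \alpha_k):\quad & \beta_i \sim \mathrm{Beta}\Big(\alpha_i, \sum_{j=i+1}^{k} \alpha_j\Big), && i = 1, \dots, k-1, \quad \beta_k = 1 \\
\text{in both cases:}\quad & \pi_i = \beta_i \prod_{j=1}^{i-1} (1 - \beta_j)
\end{aligned}
$$

For the symmetric case used in the second model ($\alpha_i = \alpha$), this gives $\beta_i \sim \mathrm{Beta}(\alpha, (k-i)\alpha)$, so every component has prior mean $1/k$. Under the first model, for a fixed $\alpha$, the expected weights decay geometrically: $E[\pi_i] = \frac{1}{1+\alpha}\left(\frac{\alpha}{1+\alpha}\right)^{i-1}$.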

For example, sampling from the stick-breaking construction:

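```python
import numpy as np
import scipy.stats as st

def stick_breaking(beta):  # numpy version; beta has shape (k, size), broken along axis 0
    left = np.concatenate([np.ones_like(beta[:1]), np.cumprod(1 - beta, axis=0)[:-1]])
    return beta * left

k = 3
size = 10000
a = st.halfnorm.rvs(loc=0., scale=1, size=size)
b = st.beta.rvs(1, a, size=(k, size))
p = stick_breaking(b).T
```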


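Plotting `p` [figure not preserved]. For comparison, sampling from the Dirichlet distribution with the same concentrations `a`: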

```python
p = []
for ia in a:
    try:
        p_tmp = st.dirichlet.rvs(np.ones(k)*ia)
    except ZeroDivisionError:
        # redraw the concentration once if sampling fails (e.g. for a very small ia)
        ia = st.halfnorm.rvs(loc=0., scale=1)
        p_tmp = st.dirichlet.rvs(np.ones(k)*ia)
    p.append(p_tmp)

p = np.squeeze(np.asarray(p))
```
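To put numbers on the difference without relying on the plots, here is a minimal sketch that compares the mean weight per component. Assumptions: the concentration is fixed at `alpha = 1.0` (an arbitrary choice so the expected values are easy to check), NumPy's `default_rng` is used instead of `scipy.stats`, and `stick_breaking` is a hypothetical row-wise helper written for this example. It also samples the symmetric Dirichlet a second time via stick-breaking with `Beta(alpha, (k - i) * alpha)` to show that the Dirichlet *is* a stick-breaking construction, just with different Beta parameters than the first model.

```python
import numpy as np

def stick_breaking(beta):
    # beta: (size, k); pi_i = beta_i * prod_{j<i} (1 - beta_j), broken along axis 1
    left = np.concatenate(
        [np.ones_like(beta[:, :1]), np.cumprod(1 - beta, axis=1)[:, :-1]], axis=1
    )
    return beta * left

rng = np.random.default_rng(0)
k, size, alpha = 3, 100_000, 1.0  # alpha fixed purely for illustration

# Truncated DP-style stick-breaking, as in the first model: beta_i ~ Beta(1, alpha)
p_dp = stick_breaking(rng.beta(1.0, alpha, size=(size, k)))

# Symmetric Dirichlet(alpha, ..., alpha), sampled directly
p_dir = rng.dirichlet(np.full(k, alpha), size=size)

# The same Dirichlet via stick-breaking: beta_i ~ Beta(alpha, (k - i) * alpha), last beta = 1
b = np.column_stack(
    [rng.beta(alpha, (k - i) * alpha, size=size) for i in range(1, k)] + [np.ones(size)]
)
p_dir_sb = stick_breaking(b)

print(p_dp.mean(axis=0))      # roughly [0.5, 0.25, 0.125]
print(p_dir.mean(axis=0))     # roughly [1/3, 1/3, 1/3]
print(p_dir_sb.mean(axis=0))  # roughly [1/3, 1/3, 1/3]
```

With `alpha = 1` the first construction puts about half of the mass on the first component on average, while the Dirichlet spreads it evenly, which is why the two mixture models behave so differently.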


One additional detail you should be aware of: the truncated stick-breaking does not guarantee that the weight vector sums to 1. This is not a problem when k is large, but it could raise an error when k is small.
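The missing mass is exactly the part of the stick that is never broken off, $\prod_{i=1}^{k}(1-\beta_i)$, whose expectation for a fixed $\alpha$ is $(\alpha/(1+\alpha))^k$. The sketch below checks this numerically and shows one common remedy: setting the last $\beta$ to 1 so the final component takes the remaining stick. The helper names `stick_breaking` and `stick_breaking_closed` are hypothetical, and `alpha = 1.0` is an arbitrary fixed value; note that a large $\alpha$ drawn from the prior makes the leftover mass larger for the same `k`.

```python
import numpy as np

def stick_breaking(beta):
    # beta: (size, k); pi_i = beta_i * prod_{j<i} (1 - beta_j)
    left = np.concatenate(
        [np.ones_like(beta[:, :1]), np.cumprod(1 - beta, axis=1)[:, :-1]], axis=1
    )
    return beta * left

def stick_breaking_closed(beta):
    # Hypothetical variant: the last break takes the whole remaining stick,
    # so every row sums to 1 (up to floating point).
    beta = beta.copy()
    beta[:, -1] = 1.0
    return stick_breaking(beta)

rng = np.random.default_rng(1)
alpha, size = 1.0, 10_000
for k in (3, 30):
    b = rng.beta(1.0, alpha, size=(size, k))
    missing = 1.0 - stick_breaking(b).sum(axis=1)  # equals prod_i (1 - beta_i)
    # expected value (alpha / (1 + alpha)) ** k: 0.125 for k=3, about 9.3e-10 for k=30
    print(k, missing.mean())
```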