Dirichlet Mixture model

For example, using stick-breaking with k = 3:

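import numpy as np
import scipy.stats as st

k = 3
size = 10000
a = st.halfnorm.rvs(loc=0., scale=1, size=size)  # concentration parameter
b = st.beta.rvs(1, a, size=(k, size))  # stick-breaking proportions
p = stick_breaking(b).T  # stick_breaking: helper not shown in this post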

Plotting p:
[image: plot of p from stick-breaking]

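For comparison, sampling p directly from a Dirichlet with the same concentration values a:

p = []
for ia in a:
    try:
        p_tmp = st.dirichlet.rvs(np.ones(k)*ia)
    except ZeroDivisionError:
        # redraw the concentration when the sampler fails on a very small value
        ia = st.halfnorm.rvs(loc=0., scale=1)
        p_tmp = st.dirichlet.rvs(np.ones(k)*ia)
    p.append(p_tmp)

# stack the draws into an array of shape (size, k)
p = np.squeeze(np.asarray(p))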

[image: plot of p from the Dirichlet]

Check out this notebook for more information: https://github.com/junpenglao/Planet_Sakaar_Data_Science/blob/master/Miscellaneous/Softmax%20normal%20compare%20with%20Dirichlet.ipynb
One additional detail to be aware of: stick-breaking with a finite k does not guarantee that the vector sums to 1. This is not a problem when k is large, but it could raise an error if k is small.
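
To make the sum-to-1 point concrete, here is a minimal sketch. The `stick_breaking` below is my own guess at what the helper does (the actual one is not shown above and may differ), and the seeded `rng` is only there for reproducibility. It checks that the weights sum to less than 1 and that the missing mass is exactly the piece of stick left after k breaks.

```python
import numpy as np
import scipy.stats as st


def stick_breaking(beta):
    """Truncated stick-breaking (assumed form of the helper used above).

    beta has shape (k, n); the returned weights have the same shape.
    """
    # Fraction of the stick still unbroken before each break:
    # 1, (1 - b1), (1 - b1)(1 - b2), ...
    remaining = np.concatenate(
        [np.ones((1, beta.shape[1])), np.cumprod(1.0 - beta, axis=0)[:-1]],
        axis=0,
    )
    return beta * remaining


rng = np.random.default_rng(0)  # seed chosen arbitrarily
k, size = 3, 10000
a = st.halfnorm.rvs(loc=0.0, scale=1.0, size=size, random_state=rng)
b = st.beta.rvs(1, a, size=(k, size), random_state=rng)
p = stick_breaking(b).T  # shape (size, k)

row_sums = p.sum(axis=1)
leftover = 1.0 - row_sums  # mass not assigned to any of the k components

# The leftover equals the product of (1 - b) over the k breaks.
print(np.allclose(leftover, np.prod(1.0 - b, axis=0)))
print(row_sums.min(), row_sums.mean())
```

With small k the leftover can be far from zero for some draws, which is where anything that validates that p sums to 1 may complain. A common workaround (not something from the notebook) is to give the leftover mass to the last component.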
