For example, with stick-breaking:
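```python
import numpy as np
import scipy.stats as st

k = 3  # number of components
size = 10000  # number of samples
a = st.halfnorm.rvs(loc=0., scale=1, size=size)  # one concentration per sample
b = st.beta.rvs(1, a, size=(k, size))  # break fractions b_i ~ Beta(1, a)
p = stick_breaking(b).T  # weights, shape (size, k); stick_breaking is defined elsewhere
```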
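The snippet calls a stick_breaking helper that it does not define. Below is a minimal NumPy sketch of what such a helper could look like; the body is an assumption, not the author's code. It assumes b has shape (k, size), so that each column is one sample, and computes w_i = b_i * prod_{j<i} (1 - b_j) along the first axis. The fixed random_state seeds are only there to make the sketch reproducible.

```python
import numpy as np
import scipy.stats as st


def stick_breaking(b):
    """Truncated stick-breaking along axis 0 (hypothetical NumPy helper).

    b has shape (k, n): b[i] is the fraction broken off whatever is left of
    the stick at step i. Returns w with w[i] = b[i] * prod_{j < i} (1 - b[j]).
    """
    # Length of stick left *before* each break: 1, (1-b0), (1-b0)(1-b1), ...
    remaining = np.concatenate(
        [np.ones_like(b[:1]), np.cumprod(1 - b, axis=0)[:-1]], axis=0
    )
    return b * remaining


k = 3
size = 10000
a = st.halfnorm.rvs(loc=0., scale=1, size=size, random_state=0)
b = st.beta.rvs(1, a, size=(k, size), random_state=1)
p = stick_breaking(b).T  # shape (size, k), one weight vector per row
print(p.shape)
```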
Plotting p, and comparing it with draws from a Dirichlet distribution that uses the same concentration a:
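```python
# One symmetric Dirichlet draw per concentration value in a
p = []
for ia in a:
    try:
        p_tmp = st.dirichlet.rvs(np.ones(k) * ia)
    except ZeroDivisionError:
        # a very small concentration can make the draw fail; redraw it once and retry
        ia = st.halfnorm.rvs(loc=0., scale=1)
        p_tmp = st.dirichlet.rvs(np.ones(k) * ia)
    p.append(p_tmp)
p = np.squeeze(np.asarray(p))  # shape (size, k)
```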
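The loop above draws one Dirichlet vector per concentration. The sketch below does the same in a single vectorized pass using the standard gamma construction (independent Gamma(a, 1) draws normalized by their sum). It uses NumPy's Generator API rather than scipy.stats, and the helper name symmetric_dirichlet_rows is hypothetical. The ZeroDivisionError in the loop is presumably raised when, for a very small concentration, every gamma draw underflows to zero; the sketch detects such rows and, similar to the loop, redraws their concentration from HalfNormal(0, 1) (sampled here as |N(0, 1)|).

```python
import numpy as np


def symmetric_dirichlet_rows(conc, k, rng):
    """One draw from Dirichlet(conc[i] * ones(k)) per entry of conc (hypothetical helper).

    Gamma construction: g_j ~ Gamma(conc[i], 1) independently, p = g / g.sum().
    Rows whose gamma draws all underflow to 0 get a fresh HalfNormal(0, 1)
    concentration and are redrawn until every row has a positive sum.
    """
    conc = np.array(conc, dtype=float)  # copy, since entries may be replaced
    g = rng.gamma(shape=conc[:, None], size=(conc.size, k))
    bad = g.sum(axis=1) == 0
    while bad.any():
        n_bad = int(bad.sum())
        conc[bad] = np.abs(rng.normal(size=n_bad))  # |N(0, 1)| ~ HalfNormal(0, 1)
        g[bad] = rng.gamma(shape=conc[bad][:, None], size=(n_bad, k))
        bad = g.sum(axis=1) == 0
    return g / g.sum(axis=1, keepdims=True)


rng = np.random.default_rng(42)
k = 3
size = 10000
a = np.abs(rng.normal(size=size))  # HalfNormal(0, 1) concentrations
p_dir = symmetric_dirichlet_rows(a, k, rng)
print(p_dir.shape)
print(p_dir.mean(axis=0))  # each column should average close to 1/k
```

Each column of a symmetric Dirichlet sample has mean 1/k, whereas the stick-breaking weights are ordered in expectation: given a, E[w_i] = (1/(1+a)) * (a/(1+a))^(i-1), so the first column tends to carry the most mass. That asymmetry is one thing to look for when plotting the two side by side.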
Check out this notebook for more information: https://github.com/junpenglao/Planet_Sakaar_Data_Science/blob/master/Miscellaneous/Softmax%20normal%20compare%20with%20Dirichlet.ipynb
One additional detail to be aware of: truncated stick-breaking does not guarantee that the weight vector sums to 1, because the last piece of the stick is never assigned. This is not a problem when k is large, but it could raise an error when k is small.
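To see why k matters: with break fractions b_1, ..., b_k, the weights sum to 1 - prod_i (1 - b_i), so the unassigned mass is prod_i (1 - b_i), which shrinks geometrically as k grows. The sketch below measures that leftover for a few values of k and shows one common remedy, assumed here rather than taken from the notebook: let the last weight take whatever remains of the stick (equivalently, fix the final break fraction to 1). The close_stick flag is a hypothetical parameter of this sketch.

```python
import numpy as np


def stick_breaking(b, close_stick=False):
    """Truncated stick-breaking along axis 0 for b of shape (k, n) (hypothetical helper).

    If close_stick is True, the final break takes the entire remaining piece,
    so every column sums to 1 (up to floating point).
    """
    b = np.array(b, dtype=float)  # copy so the caller's array is untouched
    if close_stick:
        b[-1] = 1.0
    remaining = np.concatenate(
        [np.ones_like(b[:1]), np.cumprod(1 - b, axis=0)[:-1]], axis=0
    )
    return b * remaining


rng = np.random.default_rng(0)
n = 10000
a = np.abs(rng.normal(size=n))  # HalfNormal(0, 1) concentrations
mean_leftover = {}
for k in (3, 10, 50):
    b = rng.beta(1.0, a, size=(k, n))
    # mass the truncated construction leaves unassigned, averaged over samples
    mean_leftover[k] = float((1 - stick_breaking(b).sum(axis=0)).mean())
print(mean_leftover)

closed = stick_breaking(b, close_stick=True)  # b from the last (k = 50) iteration
print(closed.sum(axis=0)[:5])
```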