So here is how I would approach it:
First, I would rewrite the explicit model so that the observed data is a flattened array - it is easier to handle shape-wise:
import numpy as np
import pymc3 as pm

# flatten the (D, N) word matrix into a 1-d array
data_flatten = np.reshape(data, data.shape[0] * data.shape[1])
# document index of each entry in the flattened array
data_index = np.repeat(np.arange(data.shape[0]), data.shape[1])
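For reference, the snippets assume data and hyperparameters roughly along these lines (the names match the code, but the sizes and simulated data here are hypothetical, just to make things runnable):

D, N = 20, 50        # D documents, N words per document
K, V = 3, 100        # K topics, vocabulary of size V
alpha = np.ones(K)   # Dirichlet prior over topics
beta = np.ones(V)    # Dirichlet prior over words
data = np.random.randint(V, size=(D, N))  # observed word ids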
with pm.Model() as model:
    # Global topic distribution
    theta = pm.Dirichlet("theta", a=alpha)
    # Word distributions for K topics
    phi = pm.Dirichlet("phi", a=beta, shape=(K, V))
    # Topic of each document
    z = pm.Categorical("z", p=theta, shape=D)
    # Words in documents
    p = phi[z][data_index]
    w = pm.Categorical("w", p=p, observed=data_flatten)
    trace = pm.sample(1000, tune=1000)
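If you want to sanity-check the shapes in this explicit model (with the hypothetical sizes above), something like:

p.tag.test_value.shape  # (D * N, V): one row of word probabilities per flattened word
w.tag.test_value.shape  # (D * N,)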
Now, to turn this into a marginalized model, we should not index phi with z as in p = phi[z], but instead keep the dimension of the latent topics. Computationally, we want the logp to be evaluated for each latent topic (not just for the "true" latent label).
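In other words, for each word we sum the likelihood over all K topics, weighted by the (now continuous) topic weights of its document - the standard mixture marginalization:

$$p(w_{d,n} \mid z_d, \phi) = \sum_{k=1}^{K} z_{d,k} \, \phi_{k,\, w_{d,n}}$$

In PyMC3, pm.Mixture does exactly this summation over components: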
with pm.Model() as model_marg:
    # Word distributions for K topics
    phi = pm.Dirichlet("phi", a=beta, shape=(K, V))
    # Topic weights of each document (continuous, so z can be marginalized over)
    z = pm.Dirichlet("z", a=alpha / 2., shape=(D, K))
    # Global topic distribution
    theta = pm.Deterministic("theta", z.mean(axis=0))
    # Words in documents
    comp_dists = pm.Categorical.dist(phi)
    w = pm.Mixture("w",
                   w=z[data_index, :],
                   comp_dists=comp_dists,
                   observed=data_flatten)
    trace_marg = pm.sample(1000, tune=1000)
You can check the shapes of the mixture weights and the per-component logp to make sure that is the case (both should carry a trailing K dimension, one entry per topic):

w.distribution.w.tag.test_value.shape
w.distribution._comp_logp(data_flatten).tag.test_value.shape
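For intuition, pm.Mixture marginalizes the discrete label by log-sum-exp-ing the weighted per-component log-probabilities. A minimal numpy sketch of that computation (the function name here is mine, not PyMC3's):

import numpy as np
from scipy.special import logsumexp

def marginal_mixture_logp(log_weights, comp_logp):
    # log_weights: (n_obs, K) log mixture weights, i.e. log of z[data_index, :]
    # comp_logp:   (n_obs, K) logp of each observation under each topic
    # returns:     (n_obs,)   log p(w) = logsumexp over the K components
    return logsumexp(log_weights + comp_logp, axis=-1)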
Of course, the two models are not identical, and there are inference problems like multi-modality in both of them… Mixture models are difficult, unfortunately.
I put everything in a notebook here: Planet_Sakaar_Data_Science/discourse_2314.ipynb at main · junpenglao/Planet_Sakaar_Data_Science · GitHub