`pm.Categorical` is not implemented to handle a high-dimensional `p`, so you need to carefully validate the shapes and reshape the input parameters:
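```python
import numpy as np
import pymc3 as pm
import theano.tensor as t

# D, K, V, W and data come from the question's setup (not shown here).
alpha = np.ones((1, K))  # symmetric Dirichlet prior over topics
beta = np.ones((1, V))   # symmetric Dirichlet prior over words
with pm.Model() as model:
    thetas = pm.Dirichlet("thetas", a=alpha, shape=(D, K))  # per-document topic mix
    phis = pm.Dirichlet("phis", a=beta, shape=(K, V))       # per-topic word distribution
    z = pm.Categorical("zx", p=thetas, shape=(W, D))         # topic of each word slot
    # phis[z] has shape (W, D, V); flatten it to 2-D for Categorical. Rows are
    # ordered (w, d), so data must be flattened in that same order.
    w = pm.Categorical("wx",
                       p=t.reshape(phis[z], (D * W, V)),
                       observed=data.reshape(D * W))
```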
Something like this seems to work, but you should double check that the shapes and reshapes behave as intended.
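One way to do that check outside PyMC is to reproduce the indexing and reshape in plain NumPy on small arrays. Because `z` has shape `(W, D)`, `phis[z]` has shape `(W, D, V)`, and a C-order reshape to `(D*W, V)` puts slot `(w, d)` at row `w*D + d`. The observations have to be flattened in that same order. The sketch below is only a bookkeeping check under stated assumptions: the helper `flatten_for_categorical`, the toy sizes, and the `(D, W)` layout of the observed word ids are hypothetical, and no model is fitted.

```python
import numpy as np


def flatten_for_categorical(phis, z, data_dw):
    """Hypothetical helper: mirror the model's reshape in plain NumPy.

    phis:    (K, V) topic-word probabilities
    z:       (W, D) integer topic assignments, as in the model
    data_dw: (D, W) observed word ids, one row per document (assumed layout)
    """
    W, D = z.shape
    V = phis.shape[1]
    gathered = phis[z]                   # integer-array indexing -> (W, D, V)
    flat_p = gathered.reshape(D * W, V)  # C order: row i is (w, d) = divmod(i, D)
    # Transpose to (W, D) first so the observations use the same (w, d) order.
    flat_obs = data_dw.T.reshape(D * W)
    return flat_p, flat_obs


# Small hypothetical sizes, only for checking shapes and ordering.
D, K, V, W = 3, 2, 5, 4
rng = np.random.default_rng(0)
phis = rng.dirichlet(np.ones(V), size=K)   # (K, V), rows sum to 1
z = rng.integers(0, K, size=(W, D))        # (W, D)
data_dw = rng.integers(0, V, size=(D, W))  # (D, W)

flat_p, flat_obs = flatten_for_categorical(phis, z, data_dw)
for i in range(D * W):
    w, d = divmod(i, D)
    # Row i of p and element i of the observations refer to the same slot.
    assert np.array_equal(flat_p[i], phis[z[w, d]])
    assert flat_obs[i] == data_dw[d, w]
print(flat_p.shape, flat_obs.shape)  # (12, 5) (12,)
```

If your data is already stored as `(W, D)`, drop the transpose; calling `data.reshape(D*W)` directly on a `(D, W)` array would instead pair word `(d, w)` with the topic of a different slot whenever `D != W`.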