I stumbled upon a closely related issue. In PyMC3 we could access the number of categories of x ~ pm.Categorical via
```python
num_categories = x.distribution.k.tolist()
```
I came up with a hacky way to do this in PyMC v4 (via x.owner.inputs), but is there a better way, @ricardoV94?
Here’s what I’m doing:
```python
# is_dirichlet is a helper defined elsewhere (not shown) that returns True for a Dirichlet RV
diris = [inp for inp in x.owner.inputs if is_dirichlet(inp)]
assert len(diris) == 1, "There should be exactly 1 Dirichlet input for a categorical variable"
num_categories_common_cause = diris[0].shape.eval()[0]
```
This works because our pm.Categorical(name, p=p) variables always take p ~ pm.Dirichlet.