Maybe something like this? Warning: I threw this together quickly, so it may contain some very dumb mistakes.
import pymc as pm
import numpy as np
# Observed skittle counts: one row per bag, one column per colour
# (column order matches ``colours``).
colours = ["red", "orange", "yellow", "green", "purple"]
data = np.array(
    [
        [3, 2, 4, 23, 4],  # bag 1
        [22, 1, 6, 2, 8],  # bag 2
    ]
)

# Bags are labelled by their row index in ``data``.
bags = [*range(len(data))]

k = len(colours)  # number of colour categories
n = len(bags)  # number of bags (distinct from the per-bag totals used as ``n=``)

# Named coordinates so model variables can be indexed by colour / bag.
coords = dict(colour=colours, bag=bags)
# Hierarchical Dirichlet-multinomial model of skittle colour counts.
#
# - ``frac``: colour proportions shared by all bags (flat Dirichlet prior).
# - ``conc``: positive concentration; the likelihood parameter is
#   ``a = frac * conc``, so a small ``conc`` lets individual bags deviate
#   more from ``frac`` and a large ``conc`` pulls them towards it.
# - Each row of ``data`` is one bag, conditioned on that bag's total count.
with pm.Model(coords=coords) as skittles_model:
    # A vector of ones gives a uniform prior over the colour-proportion simplex.
    frac = pm.Dirichlet("frac", a=np.ones(k), dims="colour")
    # Concentration must be positive; LogNormal(mu=1, sigma=1) has median e (~2.7).
    conc = pm.LogNormal("conc", mu=1, sigma=1)
    bag_of_skittles = pm.DirichletMultinomial(
        "bag_of_skittles",
        n=data.sum(axis=1),  # total number of skittles in each bag (row sums)
        a=frac * conc,  # one shared parameter vector, broadcast over bags
        observed=data,
        dims=("bag", "colour"),
    )
    # pm.sample() picks up the model from the enclosing ``with`` context,
    # so it has to stay inside this block.
    idata = pm.sample()