Help! Dirichlet-multinomial (DM) model not behaving as expected when sampling

Maybe something like this? Warning: I threw this together quickly, so it may contain some silly mistakes.

import pymc as pm
import numpy as np

# Observed skittle counts: one row per bag, one column per colour.
colours = ["red", "orange", "yellow", "green", "purple"]
data = np.array([[3, 2, 4, 23, 4],  # bag 1
                 [22, 1, 6, 2, 8]]) # bag 2
bags = list(range(data.shape[0]))
coords = {"colour": colours, "bag": bags}

k = len(colours)  # number of colour categories
n = len(bags)     # number of bags (NOT the DirichletMultinomial `n`, which is each bag's total count)

# Fail early with a readable message rather than an opaque PyMC shape error
# when the colour labels and the data columns disagree.
if data.shape != (n, k):
    raise ValueError(
        f"data has shape {data.shape}, expected (n_bags, n_colours) = ({n}, {k})"
    )


with pm.Model(coords=coords) as skittles_model:
    # Colour proportions shared by every bag; a=ones is a flat prior over the simplex.
    frac = pm.Dirichlet("frac", a=np.ones(k), dims="colour")
    # Concentration scaling `frac` into the Dirichlet parameters: larger values
    # make each bag's proportions stick closer to `frac`, smaller values allow
    # more bag-to-bag variation. `LogNormal` is the canonical PyMC >= 4 name.
    conc = pm.LogNormal("conc", mu=1, sigma=1)

    bag_of_skittles = pm.DirichletMultinomial(
        "bag_of_skittles",
        n=data.sum(axis=1),  # total number of skittles in each bag
        a=frac * conc,
        observed=data,
        dims=("bag", "colour"),
    )
    idata = pm.sample()
2 Likes