Annotator models: Multinomial model example

Hi there,

I am trying to replicate in PyMC3 the annotators.py example, which was originally written in NumPyro.

After some trial and error I came up with the following implementation. It runs, but the inference is really unstable. I suspect I don't have a clear mental model of how to express a plate (i.e., conditionally independent sampling) in PyMC3.
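
For reference, this is how I currently read the plates in terms of array shapes, as a NumPy-only sketch (no PyMC3 involved). The values of `zeta`, `c` and `annotations` are random placeholders rather than the real data, and the layout `(num_items, num_positions, num_classes)` is my assumption about how the `item` and `position` plates nest:

```python
import numpy as np

rng = np.random.default_rng(0)
num_items, num_positions, num_classes = 6, 7, 4

# Placeholder parameters: one probability row per true class ("class" plate) -> (K, K)
zeta = rng.dirichlet(np.ones(num_classes), size=num_classes)
# Placeholder true class of each item -> (I,)
c = rng.integers(0, num_classes, size=num_items)
# Placeholder annotations -> (I, P)
annotations = rng.integers(0, num_classes, size=(num_items, num_positions))

# "item" plate: each item selects the row of zeta for its class -> (I, K)
p_item = zeta[c]
# "position" plate: all positions of an item share that row -> (I, P, K)
p = np.broadcast_to(p_item[:, None, :], (num_items, num_positions, num_classes))

# Log-probability of each observed annotation -> (I, P), then sum over both plates
logp = np.log(np.take_along_axis(p, annotations[..., None], axis=-1))[..., 0]
total_logp = logp.sum()
print(p.shape, logp.shape)  # (6, 7, 4) (6, 7)
```

Conditional independence shows up only as the sum over the item and position axes; there is no explicit plate object.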

Any suggestions are very much appreciated. Thanks a lot in advance for your help.

Best,
Pietro

P.S. Cross-posted on GitHub Discussions.

import numpy as np
import pymc3 as pm
import arviz as az

# create data
def get_data():
    """
    :return: a tuple of annotator indices and class indices. The first term has shape
        `num_positions` whose entries take values from `0` to `num_annotators - 1`.
        The second term has shape `num_items x num_positions` whose entries take values
        from `0` to `num_classes - 1`.
    """
    positions = np.array([1, 1, 1, 2, 3, 4, 5])
    annotations = np.array(
        [
            [1, 1, 1, 1, 1, 1, 1],
            [3, 3, 3, 4, 3, 3, 4],
            [1, 1, 2, 2, 1, 2, 2],
            [2, 2, 2, 3, 1, 2, 1],
            [2, 2, 2, 3, 2, 2, 2],
            [2, 2, 2, 3, 3, 2, 2],
        ]
    )
    # subtract 1 because Python indices start at 0
    return positions - 1, annotations - 1

# load data and prepare
data = get_data()
positions, annotations = data
num_classes = int(annotations.max()) + 1  # labels run from 0 to num_classes - 1
num_items, num_positions = annotations.shape

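# multinomial model
with pm.Model() as multinomial_model:
    # one Dirichlet per true class: zeta[k] is the distribution of annotations
    # given true class k (the "class" plate)
    zeta = pm.Dirichlet(
        "zeta", a=np.ones((num_classes, num_classes)), shape=(num_classes, num_classes)
    )
    # prevalence of each class
    pi = pm.Dirichlet("pi", a=np.ones(num_classes))

    # latent true class of each item (the "item" plate)
    c = pm.Categorical("c", p=pi, shape=num_items)
    # every position of item i shares p = zeta[c[i]]; repeat along a new axis so that
    # p has shape (num_items, num_positions, num_classes) (the "position" plate)
    p_y = zeta[c][:, None, :].repeat(num_positions, axis=1)
    y = pm.Categorical("y", p=p_y, observed=annotations)

    # zeta and pi get NUTS; c is discrete, so PyMC3 samples it with a Gibbs-type step
    trace = pm.sample(10_000, tune=2_000, cores=4, chains=4, return_inferencedata=True)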

az.summary(trace)
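
One possible direction, sketched under assumptions and not tested in PyMC3: NumPyro can sum a discrete site like `c` out by enumeration, while PyMC3 has to sample it with a discrete step, which may be part of the instability. The NumPy/SciPy function below (`marginal_loglik` is a hypothetical helper name) computes the log-likelihood with `c` summed out, so it depends only on the continuous `zeta` and `pi`:

```python
import numpy as np
from scipy.special import logsumexp


def marginal_loglik(zeta, pi, annotations):
    """log p(annotations | zeta, pi) with each item's true class c summed out.

    zeta: (K, K), row k = annotation probabilities given true class k
    pi: (K,), class prevalence
    annotations: (I, P), integer labels in 0..K-1
    """
    log_zeta = np.log(zeta)
    # log_zeta[:, annotations] has shape (K, I, P); summing over positions gives
    # log p(y_i | c_i = k), transposed to shape (I, K)
    per_class = log_zeta[:, annotations].sum(axis=-1).T
    # add the log prior of each class, then sum over classes in log space -> (I,)
    per_item = logsumexp(per_class + np.log(pi), axis=-1)
    # items are conditionally independent, so their log-likelihoods add up
    return per_item.sum()


# Usage on random placeholder parameters and labels
rng = np.random.default_rng(0)
K, I, P = 4, 6, 7
zeta = rng.dirichlet(np.ones(K), size=K)
pi = rng.dirichlet(np.ones(K))
annotations = rng.integers(0, K, size=(I, P))
print(marginal_loglik(zeta, pi, annotations))
```

If this is on the right track, the same expression written with `pm.math.logsumexp` and added to the model through `pm.Potential` would take `c` out of the sampler entirely.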