Multivariate categorical with observed data

As a side note, I can also parameterize the Dirichlet by setting the shape directly, which I thought would give a variable “prior” of shape (2,2):

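import numpy as np
import pymc as pm  # assumption: PyMC v4+; on PyMC3 use `import pymc3 as pm`

with pm.Model() as model:
    # 2 rows of 3 observations, each a category index in {0, 1}
    observed = [[0, 1, 1], [1, 0, 1]]
    prior = pm.Dirichlet("prior", a=np.ones(2), shape=(2,))
    posterior = pm.Categorical("posterior", prior, observed=observed)
    trace = pm.sample()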

This code runs without throwing an error.
However, when I look at the trace object, the prior variable actually has shape (1,2), not (2,2).
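
To make the expectation concrete, here is a minimal NumPy-only sketch of the shape I had in mind: one probability vector of length 2 per row of the observed data. It does not use PyMC at all and says nothing about how PyMC resolves `shape` internally; the name `expected_prior` and the fixed seed are just for illustration.

```python
import numpy as np

# Same observations as in the model above: 2 rows, 3 observations per row.
observed = np.array([[0, 1, 1], [1, 0, 1]])

rng = np.random.default_rng(0)

# One Dirichlet draw per row of `observed`, each over 2 categories.
expected_prior = rng.dirichlet(np.ones(2), size=observed.shape[0])

print(expected_prior.shape)  # (2, 2)
print(expected_prior.sum(axis=1))  # each row is a probability vector
```

Each row of `expected_prior` would then supply the category probabilities for the matching row of `observed`.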