Hierarchical model fails to infer with categorical distributions

Hello folks, I’m working with a hierarchical generative model with three variables, which I’d like to define as follows:

Z ~ Bern(0.5)
Y ~ Categorical(p) where if Z=0, p=[1, 0, 0] and if Z=1, p=[0, 1/2, 1/2]
X ~ MultivariateNormal(mu, Sigma), where Sigma is the identity matrix and mu depends on Y: mu = [0, 0], [0, 2], [2, 0] for Y = 0, 1, 2 respectively.
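
As a reference for what the prior should look like, the joint prior over (Z, Y) can be tabulated exactly and checked against ancestral draws (Z first, then Y given Z, then X given Y), which is roughly what prior predictive sampling does. This is a minimal numpy sketch of the spec above; `ancestral_sample` is a hypothetical helper name, and each draw here is an independent (Z, Y, X) triple.

```python
import numpy as np

P_Z = np.array([0.5, 0.5])                       # Z ~ Bern(0.5)
P_Y_GIVEN_Z = np.array([[1.0, 0.0, 0.0],         # P(Y | Z = 0)
                        [0.0, 0.5, 0.5]])        # P(Y | Z = 1)
MU = np.array([[0.0, 0.0], [0.0, 2.0], [2.0, 0.0]])  # mean of X for Y = 0, 1, 2

# Exact joint prior P(Z = z, Y = y): rows index Z, columns index Y
joint = P_Z[:, None] * P_Y_GIVEN_Z


def ancestral_sample(n, rng):
    """Draw n independent (Z, Y, X) triples: Z first, then Y | Z, then X | Y."""
    z = rng.binomial(1, 0.5, size=n)
    y = np.array([rng.choice(3, p=P_Y_GIVEN_Z[zi]) for zi in z])
    x = MU[y] + rng.standard_normal((n, 2))  # identity covariance
    return z, y, x


rng = np.random.default_rng(0)
z, y, x = ancestral_sample(10_000, rng)

# Empirical frequency of each (Z, Y) pair, to compare with `joint`
counts = np.zeros((2, 3))
np.add.at(counts, (z, y), 1)
print(joint)
print(counts / len(z))
```

Only (Z, Y) = (0, 0), (1, 1) and (1, 2) have nonzero prior probability, so going from (1, 1) to (0, 0) means changing Z and Y at the same time.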

I’ve implemented this model in PyMC, but testing it on synthetic samples of X gives strange results. When I generate samples of X centered at [0, 2] or [2, 0], the model infers the correct states, (Z, Y) = (1, 1) and (1, 2) respectively. However, when I generate samples centered at [0, 0], which should give (Z, Y) = (0, 0), the posterior samples are instead (Z, Y) = (1, 1). The model also fails for intermediate means: I’d expect samples of X generated at [0, 1] to give posterior Y samples split roughly evenly between 0 and 1, yet every sample is drawn as Y = 1. I’ve tried implementing the model in two ways:

import numpy as np
import pymc as pm
import aesara

def gen_X_samples(X_mean, precision, nsamples):  # generate some synthetic data
    # multivariate_normal expects a covariance matrix, so the diagonal holds
    # the variance 1/precision (not its square root)
    var = 1 / precision
    X_covmat = var * np.eye(2)
    return np.random.multivariate_normal(X_mean, X_covmat, size=nsamples)

X_samples = gen_X_samples([0, 0], 1, 1000)
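
Since `np.random.multivariate_normal` takes a covariance, a precision τ should put 1/τ on the diagonal; using sqrt(1/τ) only coincides with that for τ = 1, as in the call above. A quick check with a larger precision, using a hypothetical variant `gen_X_samples_checked` that takes an explicit NumPy `Generator`:

```python
import numpy as np


def gen_X_samples_checked(X_mean, precision, nsamples, rng):
    """Isotropic 2D Gaussian samples whose per-coordinate variance is 1/precision."""
    var = 1.0 / precision          # covariance entries are variances, not SDs
    cov = var * np.eye(2)
    return rng.multivariate_normal(X_mean, cov, size=nsamples)


rng = np.random.default_rng(0)
samples = gen_X_samples_checked([0, 0], 4.0, 100_000, rng)
# Empirical variances should be close to 1/4 = 0.25 (not 1/sqrt(4) = 0.5)
print(samples.var(axis=0))
```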

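with pm.Model() as model:  # first way: index into a table of means
    Z = pm.Bernoulli('Z', p=0.5)
    # P(Y) = [0, 1/2, 1/2] if Z == 1, else [1, 0, 0]
    p_Y = pm.math.switch(pm.math.eq(Z, 1), [0, 1/2, 1/2], [1, 0, 0])
    Y = pm.Categorical('Y', p=p_Y)
    # Row k of mu_list is the mean of X when Y == k
    mu_list = aesara.shared(np.array([[0., 0.], [0., 2.], [2., 0.]]))
    set_mu = pm.Deterministic('set_mu', mu_list[Y])
    # Shape is inferred from the (1000, 2) observed array; set_mu broadcasts over rows
    X_obs = pm.Normal('X_obs', mu=set_mu, tau=1, observed=X_samples)
    trace = pm.sample()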

with pm.Model() as model:  # second way: encode the dependencies arithmetically
    Z = pm.Bernoulli('Z', p=0.5)
    p_Y0 = pm.Deterministic('p_Y0', 1 - Z)
    p_Y1 = pm.Deterministic('p_Y1', Z / 2)
    p_Y2 = pm.Deterministic('p_Y2', Z / 2)
    Y_prior = aesara.tensor.stack([p_Y0, p_Y1, p_Y2])  # stack takes a list of tensors
    Y = pm.Categorical('Y', p=Y_prior)
    # Y=0 -> [0, 0], Y=1 -> [0, 2], Y=2 -> [2, 0]
    set_mu = pm.Deterministic('set_mu', aesara.tensor.stack([Y * (Y - 1), 2 * Y * (2 - Y)]))
    X_obs = pm.Normal('X_obs', mu=set_mu, tau=1, observed=X_samples)
    trace = pm.sample()
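
As a sanity check on the arithmetic encoding in the second model, the same expressions can be evaluated in plain numpy for every value of Z and Y. `p_y_given_z` and `mu_given_y` are hypothetical helper names that mirror `p_Y0`/`p_Y1`/`p_Y2` and `set_mu`:

```python
import numpy as np


def p_y_given_z(z):
    """Same expressions as p_Y0, p_Y1, p_Y2 in the second model."""
    return np.array([1 - z, z / 2, z / 2])


def mu_given_y(y):
    """Same expression as set_mu in the second model."""
    return np.array([y * (y - 1), 2 * y * (2 - y)])


for z in (0, 1):
    print("Z =", z, "-> p_Y =", p_y_given_z(z))
for y in (0, 1, 2):
    print("Y =", y, "-> mu =", mu_given_y(y))
```

Both mappings reproduce the table in the model definition, which matches the prior predictive samples looking correct.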

I thought the problem might be the indexing in the first model, so I tried the second way, which introduces all the dependencies between variables using only arithmetic, but the result is the same. Sampling from the prior with pm.sample_prior_predictive reproduces all the dependencies between variables correctly, so I’m lost as to why the model can’t draw posterior samples with Z = 0 when given data.
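
For reference, because Z and Y are discrete and both models share a single (Z, Y) pair across all observations, the exact posterior can be computed by enumerating the six combinations and then compared with the trace. This is a minimal numpy sketch assuming identity covariance; `exact_posterior` is a hypothetical helper, not a PyMC function.

```python
import numpy as np

MU = np.array([[0.0, 0.0], [0.0, 2.0], [2.0, 0.0]])        # mean of X for Y = 0, 1, 2
P_Y_GIVEN_Z = np.array([[1.0, 0.0, 0.0], [0.0, 0.5, 0.5]])  # rows: Z = 0, 1
P_Z = np.array([0.5, 0.5])


def exact_posterior(X):
    """Return a 2x3 array with P(Z = z, Y = y | X) for one shared (Z, Y)."""
    # Gaussian log-likelihood with identity covariance, dropping the
    # normalising constant, which is the same for every Y
    loglik = np.array([-0.5 * np.sum((X - MU[k]) ** 2) for k in range(3)])
    with np.errstate(divide="ignore"):  # log(0) = -inf for impossible states
        log_joint = np.log(P_Z)[:, None] + np.log(P_Y_GIVEN_Z) + loglik[None, :]
    log_joint -= log_joint.max()        # stabilise before exponentiating
    post = np.exp(log_joint)            # exp(-inf) = 0 for impossible states
    return post / post.sum()


rng = np.random.default_rng(1)
X = rng.multivariate_normal([0, 0], np.eye(2), size=1000)
print(np.round(exact_posterior(X), 3))
```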