Marginalizing out a categorical variable

Hi everyone!

I have no previous experience with marginalizing discrete variables out of a model so that PyMC3 can use NUTS, and I can’t see how to generalize the cases discussed in the literature to mine - I am probably missing something.

Here is the model I want to work with:

p(\sigma,\alpha_0,\alpha_1,\zeta \mid \textbf{e}, \textbf{c}) \propto p(\sigma)p(\alpha_0)p(\alpha_1)p(\zeta)\prod_{i=1}^{len(\textbf{e})} p(\textbf{e}_i \mid \alpha_0, \alpha_1, \zeta, \textbf{c}_i, \sigma)

Where:

  • \textbf{e} is a vector whose i-th element is the cost for observation i (observed)
  • \textbf{c} is a vector whose i-th element is the category for observation i (observed)
  • \sigma, \alpha_0, \alpha_1 are continuous (unobserved)
  • \zeta is discrete (this is the problematic one) taking one of m possible values (unobserved)

In a different notation and somewhat more specifically:

\begin{align} \sigma & \sim \mathcal{HN}(1) \\ \alpha_0 & \sim \mathcal{N}(0,1) \\ \alpha_1 & \sim \mathcal{N}(0,1) \\ \zeta & \sim \textrm{CAT} (1/m, \cdots, 1/m) \\ \textbf{e}_i & \sim \mathcal{N}(\mu=\alpha_0+\alpha_1 L \left[\zeta,c_i \right], \sigma=\sigma) \end{align}

Where L is a (known) matrix whose (i,j)-th element is some feature of the combination of \zeta and a category that the cost depends on (among other things).

Ultimately, I am interested in finding the probability of each possible \zeta being the true one.

Here’s the code to produce some fake data:

    import numpy as np

    sigma = 0.1
    a_0, a_1 = 0, 1
    z = 3 
    L = np.array([
        [1, 2, 1],
        [3, 2, 1],
        [1, 4, 5],
        [5, 6, 8]
    ])

    # category of the ith observation
    category_i = np.repeat(np.arange(L.shape[1]), 1000)
    mu_i = a_0 + a_1 * L[z,category_i]
    outcome_i = np.random.normal(
        loc=mu_i,
        scale=[sigma]*len(mu_i)
    )

And here is the most straightforward PyMC3 model to fit the data:

    import pymc3 as pm
    import theano

    with pm.Model() as model:

        sigma = pm.HalfNormal('sigma', sigma=2)
        a_0 = pm.Normal('a_0', mu=0, sigma=1)
        a_1 = pm.Normal('a_1', mu=0, sigma=1)
        z = pm.Categorical('z', np.ones(L.shape[0]) / L.shape[0])

        mu_i = a_0 + a_1 * theano.shared(L)[z][category_i]

        outcome_i = pm.Normal(
            'outcomes',
            mu=mu_i,
            sigma=sigma,
            observed=outcome_i,
        )

        trace = pm.sample(
            cores=1,
            return_inferencedata=True
        )

The results are pretty good, with estimates close to the true values.

However, I am eventually going to have to fit a huge amount of data (with 30k possible values for \zeta and millions of observations), so efficiency is crucial (and I may use the variational API). But the discrete parameter \zeta forces PyMC3 to use CategoricalGibbsMetropolis, which makes sampling slower than it could be. I was wondering if there is some way to manipulate the model so that, instead of sampling \zeta directly, a probability vector is sampled instead?

All the cases I’ve seen around (e.g. in here) use a Normal mixture, but in those models there is a Dirichlet distribution even in the non-marginalized version, so I am unsure how to apply those models to my model above. Also I am unsure if marginalizing \zeta out is the right approach, since \zeta is precisely the value I am interested in!

Thank you very much for your help, any pointers/hints would be much appreciated!

1 Like

I notice that \zeta is not indexed by trial, which would defeat the attempts to marginalize. If \zeta_i were indexed, you could concentrate not on the specific value of \zeta but on the frequency of occurrence (hence a Dirichlet). As \zeta is not indexed, instead you’re asking for a posterior distribution over a single value. So marginalizing out \zeta in this case eliminates it from the distribution. However, it may be useful to do this in an analysis as follows:

(1) Marginalize out \zeta to obtain posteriors for other parameters

\tilde e_i \sim \frac{1}{m}\sum_{k=1}^m \mathcal{N}(\alpha_0 + \alpha_1 L[k, c_i], \sigma) = \mathcal{N}(\alpha_0 + \alpha_1 L_\mu[c_i], \sigma)\;\;\; L_\mu[j] = \frac{1}{m}\sum_{k=1}^m L[k,j]

(2) Use the marginal model with pm.sample to obtain traces for \sigma, \alpha_0, \alpha_1

(3) Use the traces to evaluate the likelihoods:

P[\mathcal{X}|\zeta = k] = \frac{1}{N} \sum_{t=1}^N \phi_\mathcal{N}(\mathcal{X}|\mu = \hat \alpha_0[t] + \hat \alpha_1[t] L[k, C]; \sigma=\hat \sigma[t])

where \hat \sigma[t] is the t-th value of the trace for \sigma. As your prior is flat, the posterior will be proportional to these conditional probabilities. You could do this directly using numpy/scipy; or you could put in a likelihood evaluation with pm.Deterministic to get it to pop out of the trace (though if L really has thousands of rows, it’s probably better to do it through numpy/scipy).
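For concreteness, here is a minimal numpy/scipy sketch of step (3), assuming the variable names from the fake-data code above (L, category_i, outcome_i) and a trace returned with return_inferencedata=True; the array names here are just illustrative:

    import numpy as np
    from scipy import stats
    from scipy.special import logsumexp

    # flatten the posterior draws of the marginal model into 1-d arrays
    sigma_s = trace.posterior['sigma'].values.ravel()
    a0_s = trace.posterior['a_0'].values.ravel()
    a1_s = trace.posterior['a_1'].values.ravel()

    m, N = L.shape[0], len(sigma_s)
    log_lik = np.empty((m, N))
    for k in range(m):
        # one row of mu per posterior draw, one column per observation
        mu = a0_s[:, None] + a1_s[:, None] * L[k, category_i][None, :]
        log_lik[k] = stats.norm.logpdf(outcome_i[None, :], loc=mu,
                                       scale=sigma_s[:, None]).sum(axis=1)

    # log P[X | zeta = k] ~ log( (1/N) sum_t exp(log_lik[k, t]) ), done in log space
    log_p_X_given_k = logsumexp(log_lik, axis=1) - np.log(N)
    # with a flat prior, normalizing gives the posterior over zeta
    post_z = np.exp(log_p_X_given_k - logsumexp(log_p_X_given_k))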

1 Like

Thank you very much, this is beautiful!

I had a feeling it had to do with marginalization, but as you said \zeta disappears (now I also understand what the Dirichlet in the other models was about), so I couldn’t quite see how it could be made useful - which your answer makes very clear.

I have a few further clarification questions, if you have time of course! I follow points 1 and 2 but I’m not sure I follow point 3 - I understand what it does but not quite how it’s doing it. Afaiu, (3) is calculating the cumulative probability of the whole dataset for each posterior sample + a specific value of k (which was marginalized out when sampling originally), and then averaging across samples. I have three questions about this:

  1. What is C? I mean, how does it relate to \mathbf{c}?
  2. What is \mathcal{X}?
  3. Why use the average cumulative probability?

I’d also be grateful for any source that discusses this method, I realize that explaining this is a lot of effort.

1 Like

The notation:

\phi_\mathcal{N}(\mathcal{X}|\mu=\hat \alpha_0[t] + \hat \alpha_1[t]L[k,C]; \sigma=\hat \sigma[t])

is an abuse of notation for “the likelihood of your data under the t-th sample”, meaning:

\prod_{i=1}^M \phi_\mathcal{N}(x_i | \mu = \hat \alpha_0[t] + \hat \alpha_1[t] L[k, c_i]; \sigma = \hat \sigma[t])

here \phi_\mathcal{N} is the pdf and not the cdf.

Now the summation \frac{1}{N}\sum_{t=1}^N is just taking the average over the entire trace. This is pure Monte Carlo integration, where the points are sampled from the joint posterior of \alpha_0, \alpha_1 and \sigma.

1 Like

Ach, I spent some time thinking about this and got more confused instead of clearer.

If I run the following code for the marginalized model (everything else being the same as in my first post), the posterior over the non-marginalized variables is very different from the posterior for those same variables in the original, non-marginalized model:

    with pm.Model() as model:

        sigma = pm.HalfNormal('sigma', sigma=2)
        a_0 = pm.Normal('a_0', mu=0, sigma=1)
        a_1 = pm.Normal('a_1', mu=0, sigma=1)

        mu_i = a_0 + a_1 * theano.shared(L).mean(axis=0)[category_i]

        outcome_i = pm.Normal(
            'outcomes',
            mu=mu_i,
            sigma=sigma,
            observed=outcome_i,
        )

        trace = pm.sample(
            cores=1,
            return_inferencedata=True
        )

(Although it’s very possible that I’m doing something wrong with the maths/model above)

Then, in point 3 from the posterior samples we calculate:

P[\mathcal{X}|\zeta = k] = \frac{1}{N} \sum_{t=1}^N \phi_\mathcal{N}(\mathcal{X}|\mu = \hat \alpha_0[t] + \hat \alpha_1[t] L[k, C]; \sigma=\hat \sigma[t])

This is a Monte Carlo approximation, for every k (i.e. for each of the m possible values of \zeta), of:

\int_{-\infty}^\infty \int_{-\infty}^\infty \int_0^\infty p (\textbf{e} \mid \zeta=k, \sigma, \alpha_0, \alpha_1, \textbf{c}) p(\sigma, \alpha_0, \alpha_1 \mid \textbf{e}, \textbf{c}) \; \mathrm{d}\sigma \mathrm{d}\alpha_0 \mathrm{d}\alpha_1

But I’m not seeing how this last expression is equal to the value of interest, namely the likelihood p(\textbf{e} \mid \zeta=k, \textbf{c}) (from which we can easily calculate the posterior).

Sorry this is getting quite long. I’d be grateful for any help! Thank you!

1 Like

The posteriors are conditional on your data \mathcal{X}, not on the value \mathbf{e} so:

\int\int\int_{\mathcal{R}} p(\zeta=k|\sigma, \alpha_0, \alpha_1, \mathbf{c}, \mathcal{X})p(\sigma, \alpha_0, \alpha_1 | \mathbf{c}, \mathcal{X})d\sigma d\alpha_0 d\alpha_1 = \int \int \int_\mathcal{R} p(\zeta=k, \sigma, \alpha_0, \alpha_1 | \mathbf{c}, \mathcal{X}) d\sigma d\alpha_0 d\alpha_1
= p(\zeta=k| \mathbf{c}, \mathcal{X})

If the posterior is really different for the marginal than for the full model, then you may want to perform a few prior predictive checks to see if a uniform prior is appropriate. Any categorical prior can be integrated out, in your case, by replacing the row average of L with a weighted row average.
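As a small illustration of that last point (a sketch; w here is a hypothetical prior weight vector over the m values of \zeta):

    import numpy as np
    w = np.ones(L.shape[0]) / L.shape[0]      # e.g. a uniform prior; any weights summing to 1
    L_w = np.average(L, axis=0, weights=w)    # weighted row average, replaces L.mean(axis=0) in mu_i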

1 Like

Maybe I’m getting a bit confused on the terminology, specifically what ‘data’ means here.

My observations are a bunch of tuples \langle e_i, c_i \rangle , where e_i is the cost of observation i (the outcome) and c_i is the category of observation i (the predictor). I am reorganizing into the vectors \textbf{e} and \textbf{c} for convenience. c here is the independent variable and e the dependent variable that I want to regress on c. Wouldn’t the posterior be conditional on \textbf{e} and \textbf{c}, which constitutes all the data I have? I guess I’m still a bit confused as to what \mathcal{X} means here beyond just \textbf{e}.

Sidelining the \mathcal{X} thing I’m not super clear on, I understand how:

\begin{align} p(\zeta \mid \textbf{c}, \mathcal{X}) & = \mathbb{E}_{p(\sigma, \alpha_0, \alpha_1 | \mathbf{c}, \mathcal{X})}\left[ p(\zeta=k|\sigma, \alpha_0, \alpha_1, \mathbf{c}, \mathcal{X}) \right] \\ & =\int\int\int_{\mathcal{R}} p(\zeta=k|\sigma, \alpha_0, \alpha_1, \mathbf{c}, \mathcal{X})p(\sigma, \alpha_0, \alpha_1 | \mathbf{c}, \mathcal{X}) \mathrm{d} \sigma \mathrm{d}\alpha_0 \mathrm{d}\alpha_1 \end{align}

But I thought the procedure in step 3 above was suggesting to first calculate the likelihood of the data given specific values of \zeta and then marginalize it, rather than go directly with p(\zeta=k|\sigma, \alpha_0, \alpha_1, \mathbf{c}, \mathcal{X}) (which I am not sure how to calculate).

I am going to think more about all this and report progress! Thank you!

1 Like

I apologize. The notation is a bit subtle, and I screwed it up. Thinking generally, you have your parameters of interest (call them \zeta), nuisances (call them \theta), and the data (call it \mathcal{X}); and you have some likelihood P(\mathcal{X}|\zeta, \theta) and priors P(\zeta, \theta) = P(\zeta)P(\theta). The approach I’ve outlined is:

(1) Q(\mathcal{X}|\theta) = \int P(\mathcal{X}|\zeta, \theta)P(\zeta)d\zeta
(2) Q_\mathrm{post}(\theta|\mathcal{X}) = \frac{Q(\mathcal{X}|\theta)P(\theta)}{\int Q(\mathcal{X}|\theta)P(\theta)d\theta}
(3) P_\mathrm{mar}(\mathcal{X}|\zeta) = \int P(\mathcal{X}|\zeta, \theta)Q_\mathrm{post}(\theta|\mathcal{X})d\theta
(4) P_\mathrm{post}(\zeta | \mathcal{X}) = \frac{P_\mathrm{mar}(\mathcal{X}|\zeta)P(\zeta)}{\int P_\mathrm{mar}(\mathcal{X}|\zeta)P(\zeta)d\zeta}

In your case \mathcal{X} = (e, c), \theta = (\sigma, \alpha_0, \alpha_1) and \zeta = (\zeta_i). Because \zeta_i is discrete, (1) is a sum rather than integral; and (4) is proportional to (3).

This is just one pass of generalized E-M; after (4) you can go back to (1), replace P(\zeta) with P_\mathrm{post}(\zeta|\mathcal{X}), and repeat the process. The procedure does converge to the true posteriors. And if, instead of computing the full integral at each step, you use only a single sample from the (iteratively-updated) posteriors, this procedure is exactly Gibbs sampling.

Given that you have lots of data, and that a categorical distribution is not particularly complicated, I would expect the procedure to converge quickly; I assume one iteration is enough.
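A compact sketch of one such pass, purely to fix ideas - fit_marginal_model, log_lik_given_k and n_iterations are hypothetical stand-ins (the first would build and sample the marginal PyMC3 model for a given weighted row average of L, the second would compute \log P_\mathrm{mar}(\mathcal{X}|\zeta=k) from the trace as in the earlier numpy/scipy sketch):

    import numpy as np
    from scipy.special import logsumexp

    m = L.shape[0]
    w = np.ones(m) / m                              # flat P(zeta) to start
    for _ in range(n_iterations):                   # often a single pass is enough
        L_w = np.average(L, axis=0, weights=w)      # (1) marginal likelihood uses this row average
        trace = fit_marginal_model(L_w)             # (2) posterior over theta = (sigma, a_0, a_1)
        log_p_X_given_k = log_lik_given_k(trace)    # (3) log P_mar(X | zeta = k) for each k
        log_post = np.log(w) + log_p_X_given_k      # (4) unnormalized log posterior over zeta
        w = np.exp(log_post - logsumexp(log_post))  #     normalize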

3 Likes

Yes this is very clear and incredibly helpful, thank you so much for taking the time!

There’s only one point left that I’m still a bit confused about: step (1), where we marginalize P(\mathcal{X}, \zeta \mid \theta) over \zeta. The idea is that the datapoints are drawn from the following distribution, which encodes the likelihood:

\tilde e_i \sim \frac{1}{m}\sum_{k=1}^m \mathcal{N}(\alpha_0 + \alpha_1 L[k, c_i], \sigma) = \mathcal{N}(\alpha_0 + \alpha_1 L_\mu[c_i], \sigma) \\ \textrm{where } L_\mu[j] = \frac{1}{m}\sum_{k=1}^m L[k,j]

which implies the generative model:

\prod_{i=1}^{len(\textbf{e})} p(\sigma)p(\alpha_0)p(\alpha_1) \phi_\mathcal{N}(\textbf{e}_i|\mu = \alpha_0 + \alpha_1 \frac{1}{m}\sum_{j=1}^m L[j, \textbf{c}_i]; \sigma= \sigma)

However, calculating the marginal generative model from ‘scratch’:

\begin{align} p(\sigma, \alpha_0, \alpha_1 \mid \textbf{e}, \textbf{c}) & = \sum_{j=1}^m p(\sigma, \alpha_0, \alpha_1, \zeta_j \mid \textbf{e}, \textbf{c}) \\ & \propto \sum_{j=1}^m \left( p(\sigma)p(\alpha_0)p(\alpha_1)p(\zeta_j)\prod_{i=1}^{len(\textbf{e})} p(\textbf{e}_i \mid \alpha_0, \alpha_1, \zeta_j, \textbf{c}_i, \sigma) \right) \\ & = p(\sigma)p(\alpha_0)p(\alpha_1) \sum_{j=1}^m \prod_{i=1}^{len(\textbf{e})} p(\textbf{e}_i \mid \alpha_0, \alpha_1, \zeta_j, \textbf{c}_i, \sigma)\\ &= p(\sigma)p(\alpha_0)p(\alpha_1) \sum_{j=1}^m \prod_{i=1}^{len(\textbf{e})} \phi_\mathcal{N}(\textbf{e}_i|\mu = \alpha_0 + \alpha_1 L[j, \textbf{c}_i]; \sigma= \sigma) \end{align}

(assuming \zeta has a uniform prior).

But isn’t this in general different from the result of marginalizing over \zeta for the individual datapoints as suggested above?

The only thing that comes to mind is that

\sum_{j=1}^m \prod_{i=1}^{len(\textbf{e})} \phi_\mathcal{N}(\textbf{e}_i|\mu = \alpha_0 + \alpha_1 L[j, \textbf{c}_i]; \sigma= \sigma)

can be rewritten as a mixture of multivariate Gaussians with mean vector \mathbf{\mu} = \left[ \alpha_0 + \alpha_1 L[j, \textbf{c}_1], \cdots, \alpha_0 + \alpha_1 L[j, \textbf{c}_n] \right] and a diagonal covariance matrix \sigma^2 I_n. But I’m not sure that helps.

Anyway, I’ll post the implementation of the series of steps you described above when I write it for future reference! Thanks again!

1 Like

I haven’t read this thread in detail, so sorry in advance if I am only adding noise, but marginalization of latent discrete variables and later recovery from posterior mixture weights is something I have been studying a bit.

In particular, there is this paper that I have been meaning to read in more detail for some time now; it discusses the difference in speed/accuracy between explicit categorical sampling and marginalization + posterior recovery:

https://esajournals.onlinelibrary.wiley.com/share/TUDMUTFR3BDAGZZCJA2M?target=10.1002/eap.2112

Maybe this relates to your last questions?

Edit: From glancing at it, the difference in accuracy might only be relevant when there are interactions between variables that determine the likelihood of being in a category vs another, whereas in my models this is usually not the case…

1 Like

Yes this looks very relevant! I am going to have a look and see if I can make heads or tails of it. Is there a discord or something where people discuss Bayesian modelling & specific papers? Seems like it might be a nice resource for fast discussions!

I prefer a forum format actually, as the information stays accessible long after the discussion has waned.

Let me know if the paper gave you any solutions. In particular whether it suggests a reason for the difference in the explicit vs recovered discrete variables (as opposed to the boring explanation that there was a mathematical mistake somewhere above)! :slight_smile:

2 Likes

It’s good to have this discussion, because many of my suggestions were made intuitively, and it’s useful (even to me) to unpack the logic behind them.

You have a mixture prior:

P(e) \propto w_1 \mathcal{N}(e|\zeta=1) + w_2 \mathcal{N}(e|\zeta=2) + ... + w_m \mathcal{N}(e|\zeta = m)

but want to avoid using a Dirichlet-mixture marginalization, since m is very large. We are interested in some kind of posterior over variables other than \zeta. Since \zeta itself is not a mixture (it is fixed for all trials, but you simply don’t know the initial state), I suggested replacing the mixture with a moment-matched distribution, since

\mathbb{E}[e] = w_1(\alpha_0 + \alpha_1 L[\zeta=1]) + w_2(\alpha_0 + \alpha_1 L[\zeta=2]) + ...
=\alpha_0 + \alpha_1 \bar L_w

If you really wanted to you could match the second moment as well with

\mathrm{Var}[e] = \sigma^2 - (\alpha_0 + \alpha_1 \bar L_w)^2 + [w_1(\alpha_0 + \alpha_1 L[\zeta=1])^2 + ... + w_m(\alpha_0 + \alpha_1 L[\zeta = m])^2]

This is probably the ultimate source of the \sigma not lining up between the two models.

Again, this would be completely wrong-headed if the \zeta_i could change; but since they are fixed we’re replacing uncertainty over a state with a single test distribution. And we’re taking this approximation since, as you say, m is large.
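As a sketch of what the moment matching amounts to for a single category j (j is a hypothetical index; a_0, a_1, sigma are taken as fixed numbers as in the fake-data code, and w are the mixture weights, here uniform):

    import numpy as np
    w = np.ones(L.shape[0]) / L.shape[0]
    mu_k = a_0 + a_1 * L[:, j]                    # component means for category j
    mean_matched = np.dot(w, mu_k)                # E[e] = alpha_0 + alpha_1 * Lbar_w[j]
    var_matched = sigma**2 + np.dot(w, mu_k**2) - np.dot(w, mu_k)**2   # Var[e] as above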

If m were small, you could apply Mixture directly:

    with pm.Model() as model:

        sigma = pm.HalfNormal('sigma', sigma=2)
        a_0 = pm.Normal('a_0', mu=0, sigma=1)
        a_1 = pm.Normal('a_1', mu=0, sigma=1)
        
        pZ = pm.Dirichlet('pZ', np.ones(L.shape[0]))
        
        # for each category
        for cat in range(L.shape[1]):
            obs_idx = np.where(category_i == cat)[0]
            muZ = a_0 + a_1 * L[:, cat]
            mix_lik = pm.NormalMixture('e_%d' % cat, w=pZ, mu=muZ, sigma=sigma, observed=outcome_i[obs_idx])
        
        trace = pm.sample()

Edit:

Although technically you would want to use your own version of Mixture since right now Mixture is giving the likelihood

L(x, \theta, w) = \prod_{j=1}^n \left(\sum_{i=1}^k w_i f_i(x_j|\theta_i)\right)

Whereas, since the \theta is fixed, you want:

L(x, \theta, w) = \sum_{i=1}^k w_i \prod_{j=1}^n f_i(x_j|\theta_i)
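A toy numeric check (made-up numbers) that the two expressions really do differ:

    import numpy as np
    # rows = components i, columns = data points x_j
    f = np.array([[0.9, 0.1],
                  [0.1, 0.9]])
    w = np.array([0.5, 0.5])
    prod_of_sums = np.prod(w @ f)                   # what Mixture computes: 0.25
    sum_of_prods = np.dot(w, np.prod(f, axis=1))    # what we want here: 0.09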

Edit x2:

It should be something like

    import theano.tensor as tt

    with pm.Model() as model:

        sigma = pm.HalfNormal('sigma', sigma=2)
        a_0 = pm.Normal('a_0', mu=0, sigma=1)
        a_1 = pm.Normal('a_1', mu=0, sigma=1)
        
        pZ = pm.Dirichlet('pZ', np.ones(L.shape[0]))
        
        # build up the likelihoods over each state
        z_logp = tt.zeros((L.shape[0],), dtype='float')
        for z_i in range(L.shape[0]):
            zi_logp = 0.
            for cat in range(L.shape[1]):
                obs_idx = np.where(category_i == cat)[0]
                muZ = a_0 + a_1 * L[z_i, cat]
                zi_logp = zi_logp + tt.sum(pm.Normal.dist(mu=muZ, sd=sigma).logp(outcome_i[obs_idx]))
            z_logp = tt.set_subtensor(z_logp[z_i], tt.log(pZ[z_i]) + zi_logp)  # log (w_i * pr[X|z_i])
            
        lp3 = pm.Deterministic('logp', z_logp)
        
        # total logp is log(w1 * pr[X|z_1] + w2 * pr[X|z_2] + ...)
        tot_logp = pm.math.logsumexp(z_logp)
        pot = pm.Potential('e', tot_logp)
        trace = pm.sample(500, cores=3, chains=6, init='advi+adapt_diag')

Something a little strange is going on (@OriolAbril ?) where pZ does not appear to be converging, but the likelihoods are consistent with it doing so:

    np.mean(trace['logp'], axis=0)
    array([-1.62101175e+07, -1.38083910e+07, -1.68103095e+07,  2.65029998e+03])
    np.mean(trace['pZ'], axis=0)
    array([0.19917654, 0.20252834, 0.1982173 , 0.40007782])

How can we be seeing total_logp values of \approx logp[3], but not have pZ \approx [0, 0, 0, 1]?

2 Likes

The scales of the logp entries are way too different to be able to use compact=True. The first 3 values have means of the order of 1e7, and it looks like they have std of the order of 1e7 (or at least 1e6), whereas the 4th is ~1e3 with a similar or smaller scale. Therefore, plotting them together gives this result: the first 3 have a “flat” distribution that gets confused with the lower axis of the plot, and the 4th looks like a delta function.

I’ll read the rest of the topic more carefully to see if I can add something else.

Sorry, the logp plot was not the issue. The issue is that logp_tot is clearly positive (indeed, it’s about equal to the mean of logp[3]), which means that

z_0 * (-1.6 \times 10^7) + z_1 * (-1.4 \times 10^7) + z_2 * (-1.7 \times 10^7) + z_3 * 2.6 \times 10^3 \approx 2.6 \times 10^3

with z_0 + z_1 + z_2 + z_3 = 1. Thus I would expect z_3 \approx 1 and z_{\neq 3} \approx 0; and yet the distribution of z (pZ on the plot) shows a rather broad distribution.
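A quick sanity check of that expectation, using the logp means printed above (with a flat prior, the implied posterior weights are just the softmax of the per-state log-likelihoods):

    import numpy as np
    from scipy.special import softmax
    logp = np.array([-1.62e7, -1.38e7, -1.68e7, 2.65e3])
    print(softmax(logp))   # ~ [0., 0., 0., 1.]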

1 Like

I still don’t completely understand what is happening, but I have two questions.

The first one is: what is logp_tot? I can’t see it in the model.

The second one: doesn’t logp already contain the multiplication by the weight?

Sorry, stale paste. Logp_tot is just the potential:

    tlp = pm.Deterministic('logp_tot', tot_logp)
    pot = pm.Potential('e', tot_logp)

and logp does have the multiplication by the weight. From the tot_logp I expect the weights to have to be near (0, 0, 0, 1) [or else the total logp would be negative, since the other likelihoods are on the order of -5.0e7 vs +2.5e3], but the weights in the trace are really broad.

It’s very instructive for me to see what approximations people make in practice; e.g. in courses you don’t encounter moment matching used in this context very much. Is the idea that, with lots of data, the mixture itself is probably going to approximate a normal distribution closely enough that matching the first two moments will be close to the true distribution?

I understand the bit about

\sum_{i=1}^k w_i \prod_{j=1}^n f_i(x_j|\theta_i) \not = \prod_{j=1}^n \left(\sum_{i=1}^k w_i f_i(x_j|\theta_i)\right)

and that the first Mixture gives the latter while we want the former, and I can follow the code in edit 2 (which was actually really nice to see, because I suspected that using a Potential + logsumexp might be one way to go; I’ve seen it done a lot by Stan people too).

I am not totally clear on what role the Dirichlet is playing in all this. The model implicit in the code is, afaiu:

p(\sigma, a_0, a_1) \phi_{\mathrm{Dir}}( w ; \mathbf{1}_m )\sum_{i=1}^k w_i \prod_{j=1}^n f_i(x_j|\theta_i)

The weight w_i is the probability that \zeta=i, right? And the Dirichlet is putting a (uniform) prior on vectors of weights (implying a uniform marginal prior over the values of \zeta). So it’s almost like the Dirichlet is putting a prior over the priors over \zeta? But I’m not sure how to interpret the mean posterior of pZ (np.mean(trace['pZ'], axis=0)).

Here’s what I would have done naively to get the marginalized posterior: instead of having a prior over priors I would use the original prior over \zeta:

p(\sigma)p(\alpha_0)p(\alpha_1) \sum_{j=1}^m \prod_{i=1}^{len(\textbf{e})} \phi_\mathcal{N}(\textbf{e}_i|\mu = \alpha_0 + \alpha_1 L[j, \textbf{c}_i]; \sigma= \sigma)

and then do the generalized EM algorithm you described above.

Adapting your code for the first step, i.e. the marginalization:

    with pm.Model() as model:

        sigma = pm.HalfNormal('sigma', sigma=2)
        a_0 = pm.Normal('a_0', mu=0, sigma=1)
        a_1 = pm.Normal('a_1', mu=0, sigma=1)
        
        # build up the likelihoods over each state
        z_logp = tt.zeros((L.shape[0],), dtype='float')
        for z_i in range(L.shape[0]):
            zi_logp = 0.
            for cat in range(L.shape[1]):
                obs_idx = np.where(category_i == cat)[0]
                muZ = a_0 + a_1 * L[z_i, cat]
                zi_logp = zi_logp + tt.sum(pm.Normal.dist(mu=muZ, sd=sigma).logp(outcome_i[obs_idx]))
            z_logp = tt.set_subtensor(z_logp[z_i], zi_logp)  # log (pr[X|z_i])
            
        lp3 = pm.Deterministic('logp', z_logp)
        
        # total logp is log(pr[X|z_1] + pr[X|z_2] + ...)
        tot_logp = pm.math.logsumexp(z_logp)
        pot = pm.Potential('e', tot_logp)
        trace = pm.sample(500, cores=3, chains=6, init='advi+adapt_diag')

And then doing steps (3) and (4) that you described above to get P_\mathrm{post}(\zeta \mid \mathcal{X}).

Is the reason for using a Dirichlet that it does steps (3) and (4) implicitly somehow?

Thanks again, I’m learning a lot from this discussion!

p(\sigma)p(\alpha_0)p(\alpha_1)\sum_{j=1}^m \prod_{i=1}^N \phi_\mathcal{N}(e_i|\mu=\alpha_0 + \alpha_1 L[j, c_i]; \sigma=\sigma)

You still need a p(\zeta = j):

p(\sigma)p(\alpha_0)p(\alpha_1)\sum_{j=1}^m p(\zeta = j) \prod_{i=1}^N \phi_\mathcal{N}(e_i|\mu=\alpha_0 + \alpha_1 L[j, c_i]; \sigma=\sigma)

You can either

(1) Fix p(\zeta = j) = \frac{1}{m}, eliminating m dimensions from sampling, and iteratively update the posterior; or:

(2) Set p(\zeta = j) = \delta_j with \delta \sim \mathrm{Dir}(1)

in which case the posterior distribution is immediate.
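A sketch of how the two options would look inside the logsumexp model above (z_logp is the per-state data log-likelihood built up in the loop; delta follows the notation of option (2)):

    import numpy as np
    import theano.tensor as tt

    # (1) flat p(zeta = j) = 1/m: a constant shift, so it could even be dropped
    tot_logp = pm.math.logsumexp(z_logp - np.log(L.shape[0]))

    # (2) Dirichlet weights delta ~ Dir(1): sample them and add log(delta_j) per state
    delta = pm.Dirichlet('delta', np.ones(L.shape[0]))
    tot_logp = pm.math.logsumexp(z_logp + tt.log(delta))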