Can you marginalize a mixture model where the draws from the different components are not independent?

I want to use a mixture model with an added constraint that the number of samples from each mixture component should be approximately equal. E.g. in the MWE below, I know that half the samples should come from each of the two components, and I ensure that this constraint is not violated too much by adding a potential term that gives a large penalty to samples that deviate too much from equal distribution.

The model samples ok with the BinaryGibbsMetropolis sampler, but I wonder if it would be possible to marginalize out the discrete variables in this case?

import pymc as pm
import numpy as np
import pandas as pd

ids = list(range(60))
hidden_assignment = [0, 1] * 30
mu = hidden_assignment.copy()
df = pd.DataFrame({"id": ids, "hidden_assignment": hidden_assignment, "y": np.random.normal(mu, 0.5, 60)})

with pm.Model(coords={"id": ids}) as balanced_mixture_model:
    group_assignment = pm.Bernoulli("group_assignment", p=0.5, dims="id")
    mu_0 = pm.Normal("mu_0", mu=0, sigma=0.1)
    mu_1 = pm.Normal("mu_1", mu=1, sigma=0.1)
    mu = pm.math.switch(group_assignment, mu_1, mu_0)
    y = pm.Normal("y", mu=mu, sigma=.5, observed=df["y"], dims="id")
    # Above: Standard normal mixture model
    # Below: Penalty term to encourage equal group sizes
    num_group_1 = pm.math.sum(group_assignment)
    pm.Potential("num_group_1_penalty", -(num_group_1 - 30 )**2)
    trace_balanced = pm.sample()

Is there a reason why you don't do something like this?

with pm.Model() as model:
    w = pm.Dirichlet('w', a=np.array([1, 1]))  # 2 mixture weights

    mu1 = pm.Normal("mu1", 0, 1)
    mu2 = pm.Normal("mu2", 0, 1)

    components = [
        pm.Normal.dist(mu=mu1, sigma=1),
        pm.Normal.dist(mu=mu2, sigma=1),
    ]

    like = pm.Mixture('like', w=w, comp_dists=components, observed=data)

Here, if you scale the parameter a in w by, say, 10 or 100, you can ensure that the weights of your mixture will be very close to 1/2, with very little variance.
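A quick check of this claim (my own NumPy sketch, not from the post above): for a symmetric Dirichlet([a, a]), the standard deviation of each weight is 0.5 / sqrt(2a + 1), so scaling a up concentrates the weights around 1/2.

```python
import numpy as np

# Sketch: how the spread of Dirichlet mixture weights shrinks as the
# concentration a grows.  For a symmetric Dirichlet([a, a]) the mean of
# each weight stays at 0.5 while sd(w[0]) = 0.5 / sqrt(2a + 1).
rng = np.random.default_rng(0)
stds = {}
for a in (1, 10, 100):
    w = rng.dirichlet([a, a], size=100_000)
    stds[a] = w[:, 0].std()
print(stds)  # roughly {1: 0.29, 10: 0.11, 100: 0.035}
```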

The p=0.5 in the Bernoulli already makes the weights equal in the mixture. The problem is that the split is only equal on average: for a given sample of the model it's possible that 20 of the individuals come from one component and 40 from the other, just by chance. I know that this should not happen, and I want to include this knowledge in the model. The potential that I added ensures this. But with this potential I don't know how to marginalize the discrete assignments.
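A quick NumPy simulation (my own sketch) of how large those chance deviations get with iid Bernoulli(0.5) assignments for 60 individuals:

```python
import numpy as np

# With 60 iid Bernoulli(0.5) indicators, the group-1 count is Binomial(60, 0.5):
# mean 30, standard deviation sqrt(60 * 0.25) ~= 3.87, so splits like 24/36
# are routine even though the mixture weights are exactly equal.
rng = np.random.default_rng(123)
counts = rng.binomial(n=60, p=0.5, size=100_000)
print(counts.mean(), counts.std(), counts.min(), counts.max())
```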

So you want the weights to be exact? Then can't you just set w=[1/2, 1/2] in the code I posted above?

If I remove the pm.Potential from my model, it will be identical to yours with w=[0.5, 0.5]. But I need the potential term to avoid samples that are inconsistent with what I know about how the data was generated. Without the potential, of the 60 observations, on average 30 will be assigned to each component, but by chance the split usually deviates from this, and the deviations can be as large as 40/20. With the potential, most samples are 30/30 and the largest deviations are 32/28.

Have you tried to see if MarginalModel can handle automatic marginalization for this case?

I can't think of a straightforward generative model that would be compatible with this. It would require the indicator variables to follow some multivariate Bernoulli distribution instead of iid Bernoullis. MarginalModel doesn't handle any multivariate discrete variables, so it's a non-starter.

According to the docs for MarginalModel, “Deterministics and Potentials cannot be conditionally dependent on the marginalized variables.”


That's correct. Your Potential trick only works with the explicitly sampled variables. If you were marginalizing the indicator variables, you would need to compute the posterior probability of each indicator variable to apply the Potential penalty term, which is not straightforward given the way PyMC models are built as a DAG of conditional dependencies.

Even if you can derive the logp of a multivariate Bernoulli that behaves as you want, marginalizing by enumeration will likely be computationally prohibitive (the number of terms grows exponentially with size, whereas in the independent case it's linear), unless there are some nice symmetries that can be exploited.
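To make the cost gap concrete, here is a small sketch of mine (plain NumPy, helper names are my own): with independent indicators the marginal logp factorizes per observation and costs O(n), while joint enumeration over assignment vectors costs O(2**n). The two agree exactly in the iid case; the balance potential couples the indicators through their sum, which is precisely what breaks the cheap factorization.

```python
import numpy as np
from itertools import product

def normal_logpdf(y, mu, sigma=0.5):
    # log-density of Normal(mu, sigma) evaluated at y
    return -0.5 * np.log(2 * np.pi * sigma**2) - (y - mu) ** 2 / (2 * sigma**2)

def logp_iid(y, mu0, mu1, p=0.5):
    # Independent indicators: marginalize each z_i on its own, cost O(n).
    return np.logaddexp(np.log(p) + normal_logpdf(y, mu1),
                        np.log(1 - p) + normal_logpdf(y, mu0)).sum()

def logp_enumerate(y, mu0, mu1, p=0.5):
    # Joint enumeration over all 2**n assignment vectors, cost O(2**n).
    terms = []
    for z in product([0, 1], repeat=len(y)):
        z = np.array(z)
        prior = np.where(z == 1, np.log(p), np.log(1 - p)).sum()
        terms.append(prior + normal_logpdf(y, np.where(z == 1, mu1, mu0)).sum())
    return np.logaddexp.reduce(terms)

y = np.array([0.1, 0.9, -0.2])
print(logp_iid(y, 0.0, 1.0), logp_enumerate(y, 0.0, 1.0))  # identical in the iid case
```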

You could model the indicator variables as a DiscreteMarkovChain that regresses to the mean of 0.5 with a couple of lag dependencies?

This is my intuition as well; I just wanted to check if I had missed something. In my case the model and dataset are small enough that sampling with the discrete variables works, but of course a speedup would be nice.

A random Google search turned up that some kinds of correlated Bernoullis may correspond to suitably parametrized Beta-Binomials… you still have the enumeration issue, though.
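For concreteness (my own sketch; the post above only mentions the connection): exchangeable Bernoullis that share a single Beta-distributed success probability are pairwise correlated with correlation 1 / (a + b + 1), and their sum is Beta-Binomial.

```python
import numpy as np

# Exchangeable correlated Bernoullis via a shared Beta(a, b) probability:
# marginally each draw is Bernoulli(a / (a + b)), any pair has correlation
# 1 / (a + b + 1), and the count is Beta-Binomial rather than Binomial.
rng = np.random.default_rng(7)
a = b = 2.0
p = rng.beta(a, b, size=200_000)
x = rng.binomial(1, np.repeat(p[:, None], 2, axis=1))  # two draws per shared p
corr = np.corrcoef(x[:, 0], x[:, 1])[0, 1]
print(corr)  # close to 1 / (a + b + 1) = 0.2
```

Note that this particular mixing construction only produces positive correlation (overdispersed counts); enforcing a near-balanced split would need negatively correlated indicators.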


Interesting, I didn't know about this model. I think it could work, although it would take a little bit of math to figure out the correct kernel. But would this actually be faster? Wouldn't you get a very deep graph that is slow to sample from?

The DiscreteMarkovChain model can be marginalized efficiently, and it also grows linearly. You would just need to find a nice n-lagged transition matrix that in the short term brings the probability of trials close to 0.5 fast enough, but not absurdly so. (You could of course have a 1-step chain with p=[[0, 1], [1, 0]], which ensures any sequence has exactly 50% ones, but it doesn't allow any variation other than whether all even draws are 1 or all odd draws are 1.)
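A quick simulation of mine of that degenerate one-step chain, to show both properties at once (exact balance, no real variation):

```python
import numpy as np

# The degenerate 1-step chain with transition matrix P = [[0, 1], [1, 0]]:
# each state flips deterministically, so any even-length realization has
# exactly 50% ones -- the only randomness left is the initial state.
P = np.array([[0.0, 1.0], [1.0, 0.0]])
rng = np.random.default_rng(0)
state = int(rng.integers(2))               # random starting parity
seq = []
for _ in range(60):
    seq.append(state)
    state = int(rng.choice(2, p=P[state]))  # flips with probability 1
print(sum(seq))  # exactly 30 ones out of 60, for either starting state
```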

On second thought I don’t think this model works. You would need a transition kernel that depends on the state, which I presume is impossible. Otherwise you would introduce short-range correlations that shouldn’t be there.

In fact it wouldn’t be a Markov model :sweat_smile:

Yes, it will introduce short-range correlations; it's just a hack. But those will exist in your model too, right?

Why not?

No, in my model there's only a global correlation: all ids are equally likely a priori to come from one component or the other, and there's nothing special about two adjacent ids.