How to do model comparison with a dummy variable


I’m trying to do a model comparison in a hierarchical meta model using a dummy variable that encodes the model choice. Sampling from the full model doesn’t work well because the chains get stuck in local optima: every chain “picks” one model and sticks to it, updating only the parameters of that model correctly.

To work around this, I have tried to sample the posteriors of all parameters first (training each submodel independently), and then to infer the posterior of the model variable based on these posteriors. However, I’m not really sure how to do that correctly:

  • When using sample_posterior_predictive, the model variable is sampled from its prior, not from its posterior. (I was able to get this to work in numpyro using its Predictive class.)
  • I’ve also tried an alternative version of the meta model in which the parameter priors are the previously estimated posteriors. That works, but I don’t think it’s correct: the observations are the same as before, so I essentially observe the same data twice. I can see that the sampled parameter values are more narrowly distributed than before.
  • Using pyro, I was able to get variational inference to work (numpyro gives me weird NaNs with the same model). This gives me the same results as the Predictive approach in numpyro.
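The narrowing in the second approach can be reproduced exactly for the global model, because the Beta prior is conjugate to the geometric likelihood: after n observations with shifted values kᵢ ≥ 1, a Beta(α, β) prior becomes Beta(α + n, β + Σ(kᵢ − 1)). A minimal sketch, assuming the Beta(0.5, 0.5) prior from the code below and made-up toy data; the helper name is hypothetical. Feeding the posterior back in as the prior and conditioning on the same data simply adds the counts a second time.

```python
import numpy as np
from scipy.stats import beta


def beta_geom_update(a, b, k):
    """Conjugate update of a Beta(a, b) prior on p given geometric data k in {1, 2, ...}."""
    k = np.asarray(k)
    # the likelihood is p**n * (1 - p)**sum(k - 1), so the counts add to (a, b)
    return a + k.size, b + np.sum(k - 1)


# hypothetical toy step sizes (starting at 0), shifted by 1 as in the thread
k = np.array([0, 2, 1, 0, 5, 3, 1, 0]) + 1

a0, b0 = 0.5, 0.5                      # Beta(0.5, 0.5) prior
a1, b1 = beta_geom_update(a0, b0, k)   # posterior after seeing the data once
a2, b2 = beta_geom_update(a1, b1, k)   # "posterior as prior", same data again

print("posterior sd:      ", beta(a1, b1).std())
print("double-counted sd: ", beta(a2, b2).std())
```

The second fit behaves as if twice as many observations had been seen, which is why its draws are narrower.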

Any ideas how to do this? (By the way, I’m aware of other forms of model comparison, but this is intended as a demonstration of the principles of Bayesian modeling, not an actual application, so I think showing the meta-model / Bayes-factor approach would be cool.)
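For reference, the meta-model / Bayes-factor view boils down to p(m | x) ∝ p(m) · p(x | m), where p(x | m) is the marginal likelihood (evidence) of submodel m. A minimal sketch, assuming you already have, or can estimate, log p(x | m) for each submodel; the function names are hypothetical and the numbers are placeholders, not results from this data.

```python
import numpy as np
from scipy.special import logsumexp


def posterior_model_probs(log_evidence, prior):
    """p(m | x) ∝ p(m) p(x | m), normalized in log space for numerical stability."""
    log_post = np.log(np.asarray(prior, dtype=float)) + np.asarray(log_evidence, dtype=float)
    return np.exp(log_post - logsumexp(log_post))


def log_bayes_factor(log_evidence, i, j):
    """log BF_ij = log p(x | m=i) - log p(x | m=j); independent of the prior p(m)."""
    return float(log_evidence[i] - log_evidence[j])


# placeholder log marginal likelihoods for models 1-3 (NOT results from this data)
log_ev = np.array([-1050.0, -1046.0, -1048.5])
prior = np.array([0.5, 0.3, 0.2])

print(posterior_model_probs(log_ev, prior))
print(log_bayes_factor(log_ev, 1, 0))  # → 4.0
```

Working in log space matters here because evidences for thousands of datapoints underflow immediately if exponentiated directly.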

Here is the code that I’m using. I’m trying to model the sizes of intervals between consecutive notes in polyphonic music. The interval size is always modeled with a geometric distribution (over the non-negative integers), but the three competing models use different predictors:

  • model 1 assumes a globally constant parameter
  • model 2 assumes different parameters for each voice (4 voices, the pieces are string quartets)
  • model 3 assumes the parameter to depend on the pitch of the preceding note (its “register” in musical terms) through a logistic function
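Models 1 and 2 happen to be conjugate (Beta prior, geometric likelihood), so their marginal likelihoods, which are exactly what a Bayes factor needs, have a closed form: log p(k) = log B(α + n, β + Σ(kᵢ − 1)) − log B(α, β). A minimal sketch, assuming the Beta(0.5, 0.5) priors from the code below, shifted data k = observations + 1, and voice indices 0 to 3; function names and toy data are hypothetical. Model 3 (logistic in pitch) has no closed form and would need a numerical estimate of its evidence instead.

```python
import numpy as np
from scipy.special import betaln


def log_evidence_beta_geom(k, a=0.5, b=0.5):
    """Exact log p(k) for k_i ~ Geometric(p) on {1, 2, ...} with p ~ Beta(a, b)."""
    k = np.asarray(k)
    n, s = k.size, np.sum(k - 1)
    # integral of p**n (1-p)**s Beta(p | a, b) dp = B(a + n, b + s) / B(a, b)
    return betaln(a + n, b + s) - betaln(a, b)


def log_evidence_global(k):
    # model 1: a single p shared by all datapoints
    return log_evidence_beta_geom(k)


def log_evidence_voice(k, staff, n_voices=4):
    # model 2: independent p per voice, so the evidence factorizes over voices
    k, staff = np.asarray(k), np.asarray(staff)
    return sum(log_evidence_beta_geom(k[staff == v]) for v in range(n_voices))


# hypothetical toy data
k = np.array([0, 2, 1, 0, 5, 3, 1, 0]) + 1
staff = np.array([0, 0, 1, 1, 2, 2, 3, 3])
print(log_evidence_global(k), log_evidence_voice(k, staff))
```

A voice with no datapoints contributes log 1 = 0, so the per-voice sum stays well defined.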

You can see this reflected in the meta model.
(The non-flat prior on the model is just there to show that sample_posterior_predictive indeed samples from the prior.)

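import numpy as np
import pymc as pm
import pytensor as ptn

# given data
observations = np.array([...])  # observed step sizes (non-negative integers)
staff = np.array([...])  # the staff/voice of each datapoint (0-3)
p0 = np.array([...])  # the pitch of the first note corresponding to each datapoint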

with pm.Model() as model_meta:
    # model choice
    model_choice = pm.Categorical("model_choice", [0.5, 0.3, 0.2])
    # global model
    theta_global = pm.Beta("theta_global", 0.5, 0.5)

    # voice model
    theta_voice = pm.Beta("theta_voice", 0.5, 0.5, shape=4)

    # register model
    a = pm.Normal("a_register", 0, 10)
    b = pm.Normal("b_register", 0, 10)
    theta_register = pm.math.sigmoid(p0*a + b)

    # observation
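    # stack the per-datapoint success probabilities of the three models (shape 3 x N)
    # and select the row belonging to the current model choice
    theta = ptn.tensor.stack((
        ptn.tensor.fill(p0, theta_global),  # model 1: same value for every datapoint
        theta_voice[staff],                 # model 2: parameter of each datapoint's voice
        theta_register,                     # model 3: logistic function of the pitch
    ))
    # pm.Geometric has support {1, 2, ...}, so the step sizes (starting at 0) are shifted by 1
    pm.Geometric("obs", p=theta[model_choice], observed=observations+1)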

I use the following auxiliary model to obtain posterior samples for the parameters:

with pm.Model() as model_joint:
    # global model
    theta_global = pm.Beta("theta_global", 0.5, 0.5)
    pm.Geometric("obs_global", p=theta_global, observed=observations+1)

    # voice model
    theta_voice = pm.Beta("theta_voice", 0.5, 0.5, shape=4)
    pm.Geometric("obs_voice", p=theta_voice[staff], observed=observations+1)

    # register model
    a = pm.Normal("a_register", 0, 10)
    b = pm.Normal("b_register", 0, 10)
    theta_register = pm.math.sigmoid(p0*a + b)
    pm.Geometric("obs_register", p=theta_register, observed=observations+1)

    # draw samples
    idata_joint = pm.sample(5_000, chains=4)

I then try to infer model_choice like this:

with model_meta:
    idata_model_choice_meta = pm.sample_posterior_predictive(idata_joint, var_names=["model_choice"])


This gives me samples from the prior (0.5, 0.3, 0.2).
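As far as I understand, this is expected: sample_posterior_predictive forward-samples variables that are not in the trace, and forward sampling never conditions on the observed data. In the meta model, the conditional of the indicator given one parameter draw is p(m | θ, x) ∝ p(m) ∏ᵢ p(xᵢ | θ_m), which can be computed directly. A minimal NumPy/SciPy sketch, assuming per-datapoint success probabilities under each model for a single draw; the function name, toy data and parameter values are hypothetical. This is the quantity a Gibbs step on m would use; averaging it over draws from separately fitted submodels is not in general the same as p(m | x), since those draws come from p(θ_m | x, m) rather than from the meta model’s joint posterior.

```python
import numpy as np
from scipy.special import expit, logsumexp
from scipy.stats import geom


def model_choice_conditional(k, thetas, prior):
    """p(m | theta, x) ∝ p(m) * prod_i Geometric(k_i | theta_m,i) for ONE parameter draw.

    k      : shifted observations (support 1, 2, ...)
    thetas : one success probability (scalar or per-datapoint array) per model
    prior  : prior model probabilities p(m)
    """
    loglik = np.array([geom.logpmf(k, th).sum() for th in thetas])
    log_post = np.log(np.asarray(prior, dtype=float)) + loglik
    return np.exp(log_post - logsumexp(log_post))


# hypothetical toy data and one made-up parameter draw per model
k = np.array([0, 2, 1, 0, 5, 3, 1, 0]) + 1
staff = np.array([0, 0, 1, 1, 2, 2, 3, 3])
p0 = np.array([60, 62, 55, 57, 48, 50, 40, 43])
theta_global = 0.4
theta_voice = np.array([0.5, 0.6, 0.2, 0.5])
a, b = 0.05, -2.5
thetas = [theta_global, theta_voice[staff], expit(p0 * a + b)]

print(model_choice_conditional(k, thetas, [0.5, 0.3, 0.2]))
```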

Any advice is welcome.


This approach is taken in Kruschke’s DBDA book (2nd ed., Chapter 10). There’s a PyMC port by @cluhmann here. Maybe this will be useful?


Awesome, thanks for the hint. I’ve tried the pseudo-prior approach and it works really well for avoiding the sampling problem.

I still wonder if there is a way to make the step-wise inference work, or whether it is valid in the first place. I feel like I’m making a simplification there that doesn’t hold.

In any case, I’ve written up a little notebook that includes both the pseudo-prior and the VI solution for my models, in case anyone is interested.

I think the Kruschke approach/example is illustrative, but should be avoided in practice (e.g., he never uses HMC/NUTS). Instead, a typical approach would be to marginalize out the indicator variable. In the case of just 2 models, you can use a single, continuous “mixing” parameter bounded to [0, 1] (e.g., using a Beta prior) and make your likelihood a weighted mixture of the 2 models. With more than 2 models, you’re looking at a Dirichlet-distributed set of mixing parameters.
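Concretely, marginalizing the indicator turns the likelihood into a weighted sum over components. A minimal sketch of the log-likelihood that a mixture likelihood such as pm.Mixture evaluates for a vector of observations, assuming geometric components on the shifted data and weights w (for 2 models, w = [μ, 1 − μ] with μ ~ Beta; for more, w ~ Dirichlet). The function name and toy values are hypothetical; logsumexp keeps the sum over components numerically stable. Note that this form mixes the components separately for each datapoint.

```python
import numpy as np
from scipy.special import logsumexp
from scipy.stats import geom


def mixture_loglik(k, w, thetas):
    """sum_i log sum_m w_m * Geometric(k_i | theta_m), computed in log space."""
    k = np.asarray(k)
    # one row of per-datapoint log-likelihoods per component, shape (M, N)
    comp = np.stack([geom.logpmf(k, th) for th in thetas])
    return float(logsumexp(comp + np.log(w)[:, None], axis=0).sum())


# hypothetical toy data and a two-model mixture with mixing weight mu
k = np.array([0, 2, 1, 0, 5, 3, 1, 0]) + 1
mu = 0.7
print(mixture_loglik(k, [mu, 1 - mu], [0.4, 0.6]))
```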

Mh, I see. So if I got this right, I would

  • add yet another level of hierarchy to the model and
    introduce the model probability as another variable (let’s say $\mu$), so the model becomes
    $p(\mu, m, \theta, x) = p(\mu) \cdot p(m \mid \mu) \cdot p(\theta) \cdot p(x \mid \theta_m, m)$
  • then marginalize out the model choice variable $m$ analytically:
    $p(\mu, \theta, x) = p(\mu) \cdot p(\theta) \cdot p(x \mid \theta, \mu)$
    (where $p(x \mid \theta, \mu)$ is the weighted mixture likelihood that you mentioned)
  • and finally look at the posterior of $\mu$ instead of the posterior of $m$.
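Written out, the marginalization step is the sum of the joint over $m$, using $p(m \mid \mu) = \mu_m$ for a categorical indicator:

$$
p(\mu, \theta, x) = \sum_{m} p(\mu) \cdot p(m \mid \mu) \cdot p(\theta) \cdot p(x \mid \theta_m, m) = p(\mu) \cdot p(\theta) \cdot \underbrace{\sum_{m} \mu_m \, p(x \mid \theta_m, m)}_{p(x \mid \theta, \mu)}
$$

If $x$ denotes the whole data set and a single model choice applies to all datapoints, then $p(x \mid \theta_m, m) = \prod_i p(x_i \mid \theta_m)$.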

Is that right?

You would just replace your indicator parameter $m$ with the mixing parameter $\mu$ so that $\mu = p(m)$ [edit: or, to be a bit more precise, $\mu = p(m=1)$].

In the case of 3 sub-models, you would need a Dirichlet-distributed parameter, something like this:

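import numpy as np
import pymc as pm

num_models = 3

with pm.Model() as model:
    # model definitions
    # replace with something useful, e.g. one pm.Geometric.dist(p=...) per model
    model_components = [
        # model 1
        # model 2
        # model 3
    ]

    # mixture weights
    w = pm.Dirichlet("w", a=np.ones(num_models))

    # likelihood
    like = pm.Mixture("like", w=w, comp_dists=model_components, observed=observations+1)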

Hi, sorry for getting back to this so much later.

I finally tried to apply the pm.Mixture() approach to my model, and I’m running into the same problems as with a naive mixture (where the dummy variable is not marginalized). After increasing target_accept to 0.9, I don’t get divergences, but sampling is extremely slow and the chains still get stuck in different regions. Is there something wrong with the model, or is marginalizing out the dummy variable not sufficient in a case like this?

Here is the full model specification:

with pm.Model() as model_mixture:
    # model weights
    model_weights = pm.Dirichlet("model_weights", np.ones(3))

    # model 1: global parameter
    theta_global = pm.Beta("theta_global", 0.5, 0.5)

    # model 2: parameters per voice
    theta_voice = pm.Beta("theta_voice", np.full(4, 0.5), np.full(4, 0.5))

    # model 3: parameter depends on preceding pitch
    a = pm.Normal("a", 0, 10)
    b = pm.Normal("b", 0, 10)
    theta_register = pm.math.sigmoid(p0*a + b)

    # mixture components
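    components = [
        pm.Geometric.dist(p=theta_global),        # model 1
        pm.Geometric.dist(p=theta_voice[staff]),  # model 2
        pm.Geometric.dist(p=theta_register),      # model 3
    ]
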
    # observation
    pm.Mixture("obs", w=model_weights, comp_dists=components, observed=observations+1)

    idata_mixture = pm.sample(1000, chains=4, target_accept=0.9)
Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 4882 seconds.
The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See for details
The effective sample size per chain is smaller than 100 for some parameters.  A higher number is needed for reliable rhat and ess computation. See for details

And here is the trace:

Would be curious to hear what you think about this.

That all looks reasonably good to me. There’s only a single chain that is failing to mix. I would expect that tweaking the tuning routine might help to take care of that. As for the fact that the model ultimately ends up preferring a mixture in which a single component is dominant, that’s likely to be a function of your components and your data.
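To see why a single chain that stays in a different region is enough to trigger the R̂ warning above, here is a minimal NumPy sketch of the classic (non-rank-normalized) split-R̂, run on synthetic draws rather than on this trace. As far as I know, PyMC’s warnings come from ArviZ’s rank-normalized R̂ by default, so actual values will differ somewhat from this simpler version; the function name is hypothetical.

```python
import numpy as np


def split_rhat(draws):
    """Classic split-R-hat for draws of shape (chains, n)."""
    draws = np.asarray(draws, dtype=float)
    half = draws.shape[1] // 2
    # split every chain into two halves so that drift within a chain also shows up
    split = np.concatenate([draws[:, :half], draws[:, half:2 * half]], axis=0)
    n = split.shape[1]
    W = split.var(axis=1, ddof=1).mean()        # mean within-chain variance
    B = n * split.mean(axis=1).var(ddof=1)      # between-chain variance
    var_hat = (n - 1) / n * W + B / n           # pooled posterior variance estimate
    return np.sqrt(var_hat / W)


rng = np.random.default_rng(0)
mixed = rng.normal(size=(4, 1000))   # four chains sampling the same distribution
stuck = mixed.copy()
stuck[3] += 5.0                      # one chain sitting in a different region

print(split_rhat(mixed), split_rhat(stuck))
```

With the shifted chain, the between-chain variance dominates and R̂ ends up far above 1.01, even though each chain individually looks well behaved.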

I see. Thanks for the feedback!