Convergence issue with finite mixture model with nutpie sampler

Hi there! I have recently switched from Stan to PyMC, mainly because of nutpie's speed compared to Stan's sampler.

I was playing around with estimating some basic mixture models where the data come from two normal distributions with different means and standard deviations, trying to recreate in PyMC the Stan example on finite mixture models shown here. Specifically, I am using pm.Potential() because I eventually want to specify mixtures of arbitrarily many processes, and pm.Potential() seemed like the way to go if I need that flexibility. The model is the following:

import pymc as pm
import numpy as np
import arviz as az

# Simulate some mixture of two normals
np.random.seed(42)
y_obs = np.concatenate([
    np.random.normal(-1, 2, size=150),
    np.random.normal(3, 1, size=350)
])

with pm.Model() as model:
    # means and SDs
    mu1 = pm.Normal('mu1', mu=-1, sigma=2)
    mu2 = pm.Normal('mu2', mu=3, sigma=2)

    sigma1 = pm.HalfNormal('sigma1', sigma=2)
    sigma2 = pm.HalfNormal('sigma2', sigma=2)

    # Mixing weights
    w1 = pm.Beta('w1', alpha=3, beta=7)  

    # Log-likelihood for each component
    logp1 = pm.logp(pm.Normal.dist(mu=mu1, sigma=sigma1), y_obs)
    logp2 = pm.logp(pm.Normal.dist(mu=mu2, sigma=sigma2), y_obs)

    # Mixture log-likelihood using log-sum-exp
    log_mix = pm.math.logsumexp(
        [pm.math.log(w1) + logp1,
         pm.math.log(1 - w1) + logp2],
        axis=0
    )

    # Add to model logp via Potential
    pm.Potential('mixture_logp', log_mix.sum())

    # Sample posterior
    trace = pm.sample(1000)

az.summary(trace)

With the default PyMC sampler, the model converges just fine and the parameters are recovered reasonably well. However, once I switch to nutpie and change the penultimate line to

    trace = pm.sample(1000, nuts_sampler="nutpie")

I consistently experience convergence issues: for each parameter, nutpie seems to sample from two distinct regions of the posterior, which is clear in the pairs plot from az.plot_pair(trace).
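For reference, this is roughly how I am inspecting the posterior (var_names is just to keep the plot readable):

# Pairwise posterior plot; the two distinct regions show up as separate clusters
az.plot_pair(trace, var_names=['mu1', 'mu2', 'sigma1', 'sigma2', 'w1'])

# r_hat values well above 1.01 also point to non-convergence
az.summary(trace)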

Does anyone have any insight into what may be happening? I am relatively new to pymc and don’t use python much, so I apologize if I am missing something obvious. Any suggestion is much appreciated!

I ran your code as provided several times and never had any convergence issues. Can you make sure all packages (pymc, pytensor, nutpie) are up to date? You can run the following and share the results:

import pymc as pm
import pytensor 
import nutpie as ntp

print(f'pymc: {pm.__version__}')
print(f'pytensor: {pytensor.__version__}')
print(f'nutpie: {ntp.__version__}')

Also, have you seen the automatic marginalization feature? It might lead to more natural model definitions in the generative sense. If you're coming from Stan, pm.Potential might feel more natural, but PyMC really wants you to write the model down in a forward, generative fashion.

You can rewrite your model in a more “pymc” way like this:

import pymc as pm
import numpy as np
import arviz as az
import pymc_extras as pmx

# Simulate some mixture of two normals
np.random.seed(42)

y_obs = np.concatenate([
    np.random.normal(-1, 2, size=150),
    np.random.normal(3, 1, size=350)
])

with pm.Model(coords={'obs_idx':range(y_obs.size), 'class':[0, 1]}) as model:
    y = pm.Data('y', y_obs, dims=['obs_idx'])
    
    # means and SDs
    mu = pm.Normal('mu', mu=[-1, 3], sigma=2, dims=['class'])
    sigma = pm.HalfNormal('sigma', sigma=2, dims=['class'])

    # Mixing weights
    w1 = pm.Beta('w1', alpha=3, beta=7)
    
    class_id = pm.Bernoulli('class_id', p=w1, dims=['obs_idx'])
    y_hat = pm.Normal('obs', mu=mu[class_id], sigma=sigma[class_id], observed=y)
    
marginalized_model = pmx.marginalize(model, [class_id])

with marginalized_model:
    # I had some issue with pickling using the numba backend, so I switched to jax
    idata = pm.sample(1000, nuts_sampler='nutpie', nuts_sampler_kwargs={'backend':'jax'})

idata = pmx.recover_marginals(marginalized_model, idata, extend_inferencedata=True)
az.summary(idata, var_names=['~class_id', '~__'], filter_vars='like')

Since you just have a mixture of Gaussians in this case, you could also use pm.NormalMixture as shown here. Internally, I think that will basically do what you're doing by hand. But marginalize is fully flexible for anything you end up wanting to do, and it also gives you the class assignment probabilities “for free”, if that's of interest.
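For reference, a minimal sketch of what pm.NormalMixture would look like here, reusing your simulated y_obs and the priors from above:

with pm.Model(coords={'class': [0, 1]}) as nm_model:
    mu = pm.Normal('mu', mu=[-1, 3], sigma=2, dims=['class'])
    sigma = pm.HalfNormal('sigma', sigma=2, dims=['class'])
    w1 = pm.Beta('w1', alpha=3, beta=7)

    # Component weights must sum to one
    w = pm.math.stack([w1, 1 - w1])

    y_hat = pm.NormalMixture('obs', w=w, mu=mu, sigma=sigma, observed=y_obs)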

I also ended up using data containers and labeled dimensions, which we really, really encourage people to use. Tutorial here.
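One payoff of the data container, as a quick sketch: you can swap in new observations of the same shape later with pm.set_data, without rebuilding the model.

with model:
    # Replace the observed data in place (same size, so the coords stay valid)
    pm.set_data({'y': np.random.normal(0, 1, size=y_obs.size)})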

Thank you for the quick reply!

So, my library versions were pymc 5.22.0, pytensor 2.30.3, and nutpie 0.13.1. I updated everything, and the problem was fixed once I upgraded nutpie to 0.15.1, the latest release.

I was aware of the marginalize function, but I wanted to check whether I could mimic what it does “by hand”. I do see how it is the best way to go in the long run, though, so I'll definitely be using it once I specify more complex models (and yes, class assignments are something I'll be looking at eventually).

I also really appreciate all the suggestions about best practices in pymc!


Forgot to mention: you can also use nutpie with Stan models: stan-usage – Nutpie
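If it helps, usage looks roughly like this, assuming bridgestan is installed (see the linked page for the exact API):

import nutpie

stan_code = """
data {
  int<lower=0> N;
  vector[N] y;
}
parameters {
  real mu;
  real<lower=0> sigma;
}
model {
  y ~ normal(mu, sigma);
}
"""

# Compile the Stan model, attach data, and sample with nutpie
compiled = nutpie.compile_stan_model(code=stan_code)
trace = nutpie.sample(compiled.with_data(N=len(y_obs), y=y_obs))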
