Why does MAP vastly outperform sampling in Bayesian clustering?

I want to apply Bayesian statistics in a clustering context and decided to use normal mixture models for this purpose. The context I want to use this in needs a bit more generality than the regular Gaussian mixture models available in other packages (see below), which is why I undertook this task.

I put together a model that performs very similarly to sklearn's mixture.GaussianMixture when one uses regular clustering centers as initial values and as centers of the normal priors (still quite vague priors around these, i.e. large SDs, see below) and uses find_MAP().

On the other hand, sample returns completely weird results unless you guide the priors very specifically. I will show the model before I write more:

import numpy as np
import pymc as pm
import pytensor.tensor as ptt


def bayesian_clustering(data, nclusters_fit, conc=1, mu_sigma=10, alpha=2,
                        beta=1, est_centers=None, sample=False):

    '''
    est_centers, if supplied, must be of shape (nclusters_fit, ndims)
    where ndims = data.shape[1]. Its purpose is to supply precalculated
    cluster centers (such as from k-means etc.) as guidance to the system.
    By increasing or decreasing mu_sigma you can control how soft the
    guidance is (smaller values are more informative, i.e. more guidance).

    If sample=True, sampling is used to find the parameters; otherwise MAP.
    '''

    if est_centers is None:
        init = [np.linspace(-np.max(np.abs(data)), np.max(np.abs(data)),
                            nclusters_fit), None]
        center_mus = [init[0], np.tile(data.mean(axis=0)[1:][None, :],
                                       (nclusters_fit, 1))]
    else:
        init = [est_centers[:, 0], est_centers[:, 1:]]
        center_mus = [est_centers[:, 0], est_centers[:, 1:]]

    ndims = data.shape[1]

    with pm.Model() as model:

        #priors
        μ0 = pm.Normal("μ0",
                       mu=center_mus[0],
                       sigma=mu_sigma,
                       shape=(nclusters_fit,),
                       transform=pm.distributions.transforms.univariate_ordered,
                       initval=init[0])

        μ1 = pm.Normal("μ1",
                       mu=center_mus[1],
                       sigma=mu_sigma,
                       shape=(nclusters_fit, ndims-1),
                       initval=init[1])

        σ = pm.InverseGamma("σ", alpha=alpha, beta=beta)
        weights = pm.Dirichlet("w", conc*np.ones(nclusters_fit))

        #transformed priors
        μ = pm.Deterministic("μ", ptt.concatenate([μ0[:,None], μ1], axis=1))
        components = [pm.Normal.dist(μ[i,:], σ) for i in range(nclusters_fit)]

        #likelihood
        pm.Mixture('like', w=weights, comp_dists=components, observed=data)

        if sample:
            trace = pm.sample(draws=4000, chains=6, tune=2000,
                              target_accept=0.95)
        else:
            MAP = pm.find_MAP()

    if sample:
        return trace, model

    return MAP, model

Some notes:
1- Note that I have used univariate normals; this is because the context I want to use this in later requires Censored distributions, which are not available for MvNormal.
2- I have imposed an ordering constraint on the first component of the centers to remedy the labelling problem a bit, but I am not convinced that this is enough. Supplying initial values hopefully remedies it a bit more, but I feel like I need to come up with something else. After all, each component of the centers is evaluated independently of the others, so the other components probably won't care about the ordering of the first one?

One of the things I have noticed by looking at the posteriors of the components of the cluster centers is that they are highly multimodal, which is, I think, expected for high dimensional mixture models. In fact, experiments with number of clusters = 2 do seem to suggest that the statistics for the cluster centers are off because of "label switching".

On the other hand, in almost all the tests I made, find_MAP() finds the true cluster centers and labels very precisely (when initial values are supplied), which I guess is expected since it converges to the nearest mode. When it does not, the uncertainty is reflected in the "right way" (i.e. when the cluster centers of the simulated data are too close or the noise is too large, the points in those clusters are confused). Sampling only works this well if I supply the cluster centers as initial values and keep mu_sigma low, such as mu_sigma=1. Anything above 5 starts creating problems.
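
For reference, this is roughly how I call it in my tests (a sketch; the k-means step and names here are just for illustration):

import numpy as np
from sklearn.cluster import KMeans

# seed the model with k-means centres; sort by the first coordinate so the
# ordered-transform initvals are monotonically increasing
km = KMeans(n_clusters=5, n_init=10).fit(data)
est_centers = km.cluster_centers_[np.argsort(km.cluster_centers_[:, 0])]

# MAP works well even with vague guidance (large mu_sigma)
map_est, model = bayesian_clustering(data, nclusters_fit=5,
                                     mu_sigma=10, est_centers=est_centers)

# sampling only behaves if the guidance is tight, e.g. mu_sigma=1
trace, model = bayesian_clustering(data, nclusters_fit=5, mu_sigma=1,
                                   est_centers=est_centers, sample=True)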

The reason I also want to be able to use sample reliably is that, down the line, I am sampling the likelihood of each point belonging to a cluster with something like this:


def sample_cluster_likelihoods(model, nclusters_fit, data, trace):
    '''
    sampling the cluster likelihoods from a given model and trace
    '''

    with model:

        μ = model.μ
        σ = model.σ
        w = model.w

        components = [pm.Normal.dist(μ[i,:], σ) for i in range(nclusters_fit)]

        # per-point log-likelihood under each component, summed over dimensions
        log_p = ptt.stack([pm.logp(components[i], data).sum(axis=1)
                           for i in range(nclusters_fit)])

        p = pm.math.exp(log_p)

        # normalize the weighted likelihoods so each point's cluster
        # probabilities sum to one
        normalization = (w[:,None]*p).sum(axis=0)

        pm.Deterministic("cluster_likelihoods",
                         w[:,None]/normalization[None,:]*p)

        pps = pm.sample_posterior_predictive(trace,
                                             var_names=["cluster_likelihoods"])

    return pps

and would like to get HDI intervals for these if I can (I can do the point-estimate version when I have the MAP).
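
For what it's worth, this is roughly how I would extract the HDIs afterwards (a sketch, assuming the pps object returned by the function above):

import arviz as az

# 94% HDI of each point's per-cluster membership probability,
# taken over the posterior-predictive draws of "cluster_likelihoods"
hdi = az.hdi(pps, group="posterior_predictive",
             var_names=["cluster_likelihoods"])
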
So my questions would be:

1- Is it safe to use MAP in a model like this (as the simulated-data results suggest)? The advantage of this over existing implementations is that I can modify the likelihood and also get a probability per cluster for each data point.
2- Is it OK to get the HDIs by doing the sampling with a relatively informative prior around a local mode found by, say, normal clustering, or dare I say by the MAP?
3- Any suggestions on how to improve this model for better sampling and for better remedying the label switching problem?

The full code, including data generation functions etc., is quite lengthy, so I dumped it here:

note: I have also just come across this topic:

which seems to pretty much describe what I am having difficulty with, and "relabelling" the chains by hand seems like the most reasonable approach. I have, however, witnessed label switching even within a chain, so this might need to be done on a per-sample basis, which could be quite slow. I guess one possibility is to add a "confusion cost matrix" that depends on an input baseline set of labels, which maybe comes from the MAP? Again, the point would be to get HDI estimates (with as soft guidance as possible), not necessarily to find better centres than MAP.
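
To make the relabelling idea a bit more concrete, here is the kind of helper I have in mind (an untested sketch; map_centers would be the centres from find_MAP):

import numpy as np
from scipy.optimize import linear_sum_assignment

def relabel_draw(draw_centers, map_centers):
    '''
    Find the permutation of clusters in one posterior draw that best matches
    the MAP centres (Hungarian algorithm on pairwise distances).
    Both inputs have shape (nclusters, ndims).
    '''
    cost = np.linalg.norm(map_centers[:, None, :] - draw_centers[None, :, :], axis=-1)
    _, perm = linear_sum_assignment(cost)
    # apply perm to the cluster axis of that draw, e.g. draw_centers[perm], weights[perm]
    return perm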

Typically, if you are getting better results with a "simpler" method (MAP, Metropolis, etc.), your model has some kind of misspecification. One of the best and most frustrating things about NUTS is that it fails very loudly.

I'm a beginner when it comes to mixture models, but I spent some time playing with your model. As written, you end up with really nonsensical results in the non-sorted dimension. After some consultation with @ricardoV94, I think it's really important that you use a mixture of Multivariate Normals rather than a mixture of Normals. My intuition (and likely yours) was that these are equivalent, but in the context of a mixture model they're not. Using a mixture of Normals is like a mixture of 10 exchangeable components (5 clusters times 2 dimensions), whereas a mixture of 5 MultivariateNormals encodes an unbreakable connection between each (x, y) pair.

After that change things got much more sensible. To get it across the finish line I had to tinker with priors and sampling options. I needed quite tight sigma priors to make sure the distributions didn't "bleed" into each other. I also got better results with the PyMC NUTS implementation than with JAX, but I don't have a good explanation for that. My hypothesis is that the uni-modal nature of ADVI pushed the mixture weights towards 1/n, which ends up being the right answer. You could also experiment with setting the concentration parameter higher. Without the ADVI initialization, the sampler often found sparse solutions, which we really don't want (in this problem anyway).

Here’s the data generation:

import numpy as np
import matplotlib.pyplot as plt
import pymc as pm
import pytensor.tensor as pt
from sklearn.datasets import make_blobs
from sklearn.preprocessing import StandardScaler

n_clusters = 5
data, labels = make_blobs(n_samples=1000, centers=n_clusters, random_state=10)

scaler = StandardScaler()
scaled_data = scaler.fit_transform(data)
plt.scatter(*scaled_data.T, c=labels)

[figure: scatter plot of the scaled data, coloured by true cluster label]

And the model:

coords={"cluster": np.arange(n_clusters),
        "obs_id": np.arange(data.shape[0]),
        "coord":['x', 'y']}

# When you use the ordered transform, the initial values need to be 
# monotonically increasing
sorted_initvals = np.linspace(-2, 2, 5)

with pm.Model(coords=coords) as model:
    # Use alpha > 1 to prevent the model from finding sparse solutions -- we know
    # all 5 clusters should be represented in the posterior
    w = pm.Dirichlet("w", np.full(n_clusters, 10), dims=['cluster'])
    
    # Mean component
    x_coord = pm.Normal("x_coord", sigma=1, dims=["cluster"], 
                        transform=pm.distributions.transforms.ordered,
                        initval=sorted_initvals)
    y_coord = pm.Normal('y_coord', sigma=1, dims=['cluster'])
    centroids = pm.Deterministic('centroids', pt.concatenate([x_coord[None], y_coord[None]]).T,
                                dims=['cluster', 'coord'])
    
    # Diagonal covariances. Could also model the full covariance matrix, but I didn't try.
    sigma = pm.HalfNormal('sigma', sigma=1, dims=['cluster', 'coord'])
    covs = [pt.diag(sigma[i]) for i in range(n_clusters)]
    
    # Define the mixture
    components = [pm.MvNormal.dist(mu=centroids[i], cov=covs[i]) for i in range(n_clusters)]
    y_hat = pm.Mixture("y_hat",
                       w,
                       components,
                       observed=scaled_data,
                       dims=["obs_id", 'coord'])

    idata = pm.sample(init='advi+adapt_diag')

Results:

[figures: posterior plots of the fitted model]

If you’re interested in the probability of each point, you can use a predictive model to go back and compute that:

with pm.Model(coords=coords) as likelihood_model:
    w = pm.Flat("w", dims=['cluster'])
    
    x_coord = pm.Flat("x_coord", dims=["cluster"])
    y_coord = pm.Flat('y_coord', dims=['cluster'])
    centroids = pm.Deterministic('centroids', pt.concatenate([x_coord[None], y_coord[None]]).T,
                                dims=['cluster', 'coord'])
    
    sigma = pm.Flat('sigma', dims=['cluster', 'coord'])
    covs = [pt.diag(sigma[i]) for i in range(n_clusters)]
    
    components = [pm.MvNormal(f'component_{i}', mu=centroids[i], cov=covs[i]) for i in range(n_clusters)]
    # log-probability of each observation under each component
    p_component = pm.Deterministic('p_component', pt.concatenate([pm.logp(d, scaled_data)[:, None] for d in components],
                                                        axis=-1),
                           dims=['obs_id', 'cluster'])
    label_hat = pm.Deterministic('label_hat', pt.argmax(p_component, axis=-1), dims=['obs_id'])
    idata_pred = pm.sample_posterior_predictive(idata, var_names=['p_component', 'label_hat'])
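
If you want a single hard label per point, you could, for example, take a majority vote over the sampled label_hat draws (a quick sketch, assuming the idata_pred object above):

import numpy as np

# majority vote over all chains/draws of label_hat -> one hard label per observation
label_draws = idata_pred.posterior_predictive["label_hat"]            # dims: chain, draw, obs_id
flat = label_draws.values.reshape(-1, label_draws.sizes["obs_id"])    # (n_samples, n_obs)
hard_labels = np.array([np.bincount(col).argmax() for col in flat.T])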

Hi,

Thanks a lot for the detailed analysis.

1- Which version are you running? I am on 5.9.1, and running your code as-is starts off unreasonably slowly:

Auto-assigning NUTS sampler...
Initializing NUTS using advi+adapt_diag...
 |--------------------| 0.02% [44/200000 00:09<12:28:29 Average Loss = 3,736.7]

I just ran the code above as-is, after adding the necessary imports. It does not seem to be an issue with ADVI. If I try:

idata = pm.sample(1000, tune=1000, chains=3)

I get:

Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (3 chains in 4 jobs)
NUTS: [w, x_coord, y_coord, sigma]
 |--------------| 0.35% [21/6000 00:15<1:12:28 Sampling 3 chains, 0 divergences]

which does not improve much (in fact it gets worse!). I thought this was an issue that could be resolved by supplying correct initial values (as a test); surprisingly, that did not fix anything. I removed scaling, no luck.
Finally, it ran at a reasonable speed only when I changed MvNormal to Normal. So at the moment I am not able to play around with your code, and whatever I write below is a bit speculative.

2- You are using MvNormal with a diagonal covariance, which means that the components of the MvNormal are independent normal distributions (though they might have different variances, since that is left as a parameter). So from a mathematical point of view, there is no difference from using a set of independent normals. I am confused as to why that would make a difference, unless the sampler treats a set of normals and an MvNormal with diagonal covariance differently. Unfortunately, MvNormal is not an option for me, because later on I am going to use this on data where censoring is present. Since censoring is not applicable to MvNormal, I cannot use it. There is also currently no reason for me to use MvNormal, because I have a pretty good guess that the features I am trying to cluster on are independent of each other (i.e. the clusters are spherical from a geometric point of view). Can you or @ricardoV94 elaborate on why there would be a difference between a diagonal-covariance MvNormal and a set of normals? Also, if MvNormal is the suggested practice, is there any way to couple an MvNormal with diagonal covariance with censoring?

3- I did run the model above, just changing the MvNormals to Normals, and when I plot the trace, the results I get are nowhere near as good as yours. So, against my intuition (and yours), there does seem to be a difference between an MvNormal with diagonal covariance and a set of normals. However, I cannot play around with MvNormal because the optimization is extremely slow. I am still going to leave it running for a while in hopes that it may speed up later.

PS: I have been running this code for a while and it shows no signs of speeding up! I am on a pretty decent computer.

PS2: I have followed the installation instructions at the link below (I also installed numpyro, blackjax and nutpie, but will try uninstalling these later):
https://www.pymc.io/projects/docs/en/latest/installation.html
I don't get any warnings when I import pymc.

I ran the model on a 2021 MacBook Pro M1; it took about 10 minutes in total (4 minutes of ADVI, 6 minutes of sampling). JAX took about 4 minutes in total, but I got worse traces.

The fact that MvNormal specifically is giving you grief makes me worry about your BLAS installation. If you're on an M1/M2 Mac, make sure you're using Accelerate; otherwise make sure you're using MKL. If you're on Linux, slow sampling rates on MvNormal models have been reported on Ubuntu: see here, and the associated issue here.

OK, funny thing: sampling with MvNormal works fine when I use pymc 5.0 (though some of the other suggestions in the links don't make much of a difference; I haven't dug too deep). I am going to try to pinpoint the latest version that works fine and file a report. While I do that, any ideas why an MvNormal with diagonal covariance and a set of independent normals (which are mathematically equivalent) would generate different results? So, when I test your code in pymc 5.0:

1- Using the same code as above, sampling with plain pm.sample() still results in very multimodal posteriors for both the x and y components. So ADVI does play a role in getting those nice results, and I think it was mentioned somewhere else on this forum that ADVI might be more appropriate for mixture models.

2- Sampling with pm.sample(init='advi+adapt_diag') does give the unimodal and better results that you have shown. So it was only natural to try pm.sample(init='advi+adapt_diag') with a set of independent univariate normals; however, the result I get is, as before, quite multimodal with a lot of divergences. Setting target_accept=0.99 gives somewhat better results, similar to ADVI, but still not as good. This suggests to me that most of the improvement is coming from ADVI, but I am really surprised that there is any difference at all between an MvNormal with diagonal covariance and a set of independent normals. I will investigate this more.

3- Sampling with ADVI and univariate normals but with a single sigma parameter:

sigma = pm.InverseGamma('sigma', alpha=1, beta=2)

and target_accept=0.95 gives almost as good results for the x components as MvNormal and the original code (the y components are still multimodal).

A (PyMC) mixture of MvNormals vs univariate Normals is not equivalent, even with unit covariance. Mixture treats the x, y pairs as exchangeable across univariate components (i.e., one x could come from one component and the y from another). This is not the case with MvNormal components, where both x and y must come from the same component.

In your model this means that when you solve the x labeling problem with the ordering constraint, the y is now also well identified. If you use univariate normals, the y variate is still subject to label switching because it is independent of x.


A few code snippets to illustrate the point:

import numpy as np
import pymc as pm

# Create two (seemingly) equivalent mixtures
cov = [[1, 0], [0, 1]]
w = [0.5, 0.5]
diag_mvn_mix = pm.Mixture.dist(w, comp_dists=[pm.MvNormal.dist([-10, -10], cov), pm.MvNormal.dist([10, 10], cov)])
ind_norm_mix = pm.Mixture.dist(w, comp_dists=[pm.Normal.dist([-10, -10]), pm.Normal.dist([10, 10])])

Each mixture has two components, one very negative and one very positive. Within each component there are two elements, both with the same mean.

Our naive expectation is that the resulting mixtures should only draw two positive numbers or two negative numbers. The next code tests that theory:

# Draw samples
diag_mvn_samples, ind_norm_samples = pm.draw([diag_mvn_mix, ind_norm_mix], 100_000)

# Percentage of draws from the MvNormal mixture that agree in sign
# (dividing the count by 1000 turns it into a percentage of the 100,000 draws)
(np.all((diag_mvn_samples > 0), axis=1) | np.all((diag_mvn_samples < 0), axis=1)).sum() / 1000
>>> Out: 100.0

# Percentage of draws from the independent-normal mixture that agree in sign
(np.all((ind_norm_samples > 0), axis=1) | np.all((ind_norm_samples < 0), axis=1)).sum() / 1000
>>> Out: 49.9

What is happening? The independent normals are being mixed in every dimension, because the PyMC model cannot tell that those two elements form a core (support) dimension rather than a batch dimension. I was able to grok it by thinking about the generative process. In the MvNormal case, we flip a coin to choose a component and then sample the whole vector from it. In the independent-Normal case, we traverse the batch dimensions (all dimensions to the left of the component dimension) and flip a coin for each one. That's how we end up with mixed signs. Note that this doesn't happen in the MvNormal case only because PyMC knows that it's a multivariate distribution. If we had an additional batch dimension (for example, if we wanted to sample a 3-tuple of (x, y) coordinates), you would see the same "multi-flipping" behavior in the MvNormal case too.
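
You can see the same asymmetry in the logp of a single mixed-sign point (my own quick check, reusing the two mixtures defined above):

point = np.array([10.0, -10.0])   # signs disagree

# the MvNormal mixture assigns this essentially zero density: the whole pair
# has to come from a single component, and neither component is near (10, -10)
pm.logp(diag_mvn_mix, point).eval()

# the independent-normal mixture is perfectly happy with it, because each
# element can pick its own component
pm.logp(ind_norm_mix, point).sum().eval()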

This also has consequences for logp evaluations: it means that the non-sorted coordinate (y in your case) will never become "attached" to a single x, and you will observe label switching. Here is an example of my results using independent normals:

[figure: trace plot from the independent-normal model]

You can see that the x coordinates are all correct, but where distributions are aligned vertically there is mode switching in the y dimension. You must use an MvNormal to prevent this from happening.


Also note that MvNormal is not a silver bullet. If you have a lot of vertical overlap across clusters you may still have switching problems.

Depending on the problem, you may want to sort the y means instead, or perhaps use a different representation of the space (like polar coordinates) that better disambiguates the mean coordinates.
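
For example, sorting the y means would just mean moving the ordering constraint in the model above (a sketch, not tested):

# inside the model above: constrain y_coord to be ordered instead of x_coord
y_coord = pm.Normal("y_coord", sigma=1, dims=["cluster"],
                    transform=pm.distributions.transforms.ordered,
                    initval=np.linspace(-2, 2, n_clusters))
x_coord = pm.Normal("x_coord", sigma=1, dims=["cluster"])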


Is there any workaround by which I could have the non-mixing property of MvNormal while using censored univariate normals? If that is not straightforward, I will open another topic illustrating the context in which I need such a model.

You can define your logp manually in a CustomDist if you want to use Censored Normals but treat the x,y pairs as non-exchangeable.

However, I am not sure whether censoring will make complete sense in this case, since it will be axis-aligned. What if censoring happens when x + y > b?

And are you sure you are working with censoring and not truncation? Sounds interesting either way.

Do you have an example of code that generates synthetic data?

Thanks, I opened a new topic on this:

However, a short reply: since this is an MvNormal with diagonal covariance, its pdf is actually just a product of two independent normals, which could in theory be censored independently. I am not sure how to implement that, though, or whether it is possible while still keeping the non-mixing property of the MvNormal.
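
Something along these lines is what I would naively try (a rough, untested sketch; the censoring bounds, shapes, priors and dummy data here are placeholders), building the mixture logp by hand so that each observation vector is evaluated under a single component:

import numpy as np
import pymc as pm
import pytensor.tensor as pt

def censored_component_logp(value, mu, sigma, lower, upper):
    # per-coordinate censored normal, summed over coordinates so the whole
    # observation vector is tied to one component
    comp = pm.Censored.dist(pm.Normal.dist(mu=mu, sigma=sigma),
                            lower=lower, upper=upper)
    return pm.logp(comp, value).sum(axis=-1)

nclusters, ndims = 5, 2
lower, upper = -2.0, 2.0          # placeholder censoring bounds
# dummy censored data just so the sketch runs; in practice this is the observed array
data = np.clip(np.random.default_rng(0).normal(size=(200, ndims)), lower, upper)

with pm.Model() as censored_mix:
    w = pm.Dirichlet("w", np.full(nclusters, 10.0))
    mu = pm.Normal("mu", 0.0, 1.0, shape=(nclusters, ndims))
    sigma = pm.HalfNormal("sigma", 1.0, shape=(nclusters, ndims))

    # (nclusters, n_obs) matrix of per-component log-likelihoods
    comp_logp = pt.stack([censored_component_logp(data, mu[k], sigma[k], lower, upper)
                          for k in range(nclusters)])

    # marginal mixture likelihood: logsumexp over components, summed over points
    # (ordering/label-switching handling omitted here)
    pm.Potential("like", pm.math.logsumexp(pt.log(w)[:, None] + comp_logp, axis=0).sum())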