Problems with Hierarchical Normal Mixture Model

Hi! I am using a hierarchical mixture model (trying the linear local Dirichlet model described here, see page 10). My xdata is an array with entries (N, Z), about 2000 data points in total. The model is defined as follows:

import pymc as pm

mixture_model = pm.Model()

with mixture_model:
    # Tried a Gamma prior for sigma, but it was slow; HalfNormal with sigma=10
    # is representative of the approximate rms.
    sigma = pm.HalfNormal("sigma", sigma=10, shape=n_models)
    beta = pm.Normal("beta", mu=0, sigma=1, shape=(n_models, 2))
    x = pm.Data("x", xdata)
    logalpha = pm.math.matmul(beta, x.T)
    weights = pm.Dirichlet("w", a=pm.math.exp(logalpha.T))

    y_obs = pm.NormalMixture("y", w=weights, mu=mu, sigma=sigma, observed=BEA, shape=1)

The problem I am having is extremely slow sampling, and the results are not very good either. Running ~1000 draws and 1000 tuning steps for four chains takes about 1-2 hours, and running more, say 25 000 draws and 25 000 tuning steps, takes about 24 hours but returns an ESS of ~4 and r-hat of roughly 3 or more. I should also point out that I have managed to remove all divergences by adjusting the prior on sigma to the reasonable sigma=10, but the sampling is still very slow.
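For reference, the ESS and r-hat numbers above are what I read off the ArviZ summary, roughly along these lines (just a sketch of how I check the diagnostics, with idata being the result of pm.sample):

import arviz as az

with mixture_model:
    idata = pm.sample(draws=1000, tune=1000, chains=4)

# The ess_bulk and r_hat columns are where the ~4 and >3 figures come from.
print(az.summary(idata, var_names=["beta", "sigma"]))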

I realize that the 4000 weight parameters here are a lot, but the paper I am following must have even more, since they use multiple models, so I am wondering whether I have made a mistake somewhere.

I am thankful for any help that you can give. Please let me know if any additional information is needed.


What is the mu that goes into NormalMixture? It does not seem to be defined in the model. Also, how many clusters are we talking about? Is it really 4000, or did you broadcast your weights into some shape compatible with the data, and that is why it is so big? I can't tell what is going on in your Dirichlet definition because I don't know the shape of logalpha, but if, for instance, you have 2000 data points and want two clusters per point, with the weights possibly depending on the data, that should not be a major cause of slowdown. That would mean your weights have shape (2000, 2), I think, which is also what a weight of shape (2,) would have been broadcast to. Perhaps you can post the shapes of your observed, mu, and weights so we can get a better idea (or better yet, post the complete code).
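For concreteness, this is roughly the shape layout I have in mind for per-data-point weights (just a sketch with placeholder sizes, priors, and fake data, not your model):

import numpy as np
import pymc as pm

N, K = 2000, 2                 # data points, mixture components (placeholders)
rng = np.random.default_rng(0)
X = rng.normal(size=(N, 2))    # stand-in for your xdata
y = rng.normal(size=N)         # stand-in for your observed BEA

with pm.Model() as sketch:
    beta = pm.Normal("beta", mu=0, sigma=1, shape=(K, 2))
    logalpha = pm.math.dot(X, beta.T)               # shape (N, K)
    w = pm.Dirichlet("w", a=pm.math.exp(logalpha))  # shape (N, K): one weight vector per point
    mu = pm.Normal("mu", mu=0, sigma=10, shape=K)   # one mean per component
    sigma = pm.HalfNormal("sigma", sigma=10, shape=K)
    pm.NormalMixture("y", w=w, mu=mu, sigma=sigma, observed=y)  # observed has shape (N,)

So the weights are (N, K) = (2000, 2), but there are still only K mixture components.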

In my experience, for instance, about 200 data points and 9 clusters on a six-dimensional problem with ~3000 draws and tuning steps can take up to an hour on a pretty decent computer. Is your problem 1D? I have not worked with 1D problems much, but maybe they are more tame.

Also, normal mixtures are very prone to label switching (which often shows up as poor mixing), so you usually need tricks like ordering the component means (and, for higher-dimensional problems, possibly a choice of coordinates better suited to your problem), see for instance:
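In PyMC, the ordering trick looks roughly like this (a sketch assuming a recent PyMC version with pm.distributions.transforms.ordered; the prior and K are placeholders):

import numpy as np
import pymc as pm

K = 2  # number of mixture components (placeholder)

with pm.Model():
    # The ordered transform enforces mu[0] < mu[1] < ..., which stops the
    # components from swapping labels between (and within) chains.
    # An increasing initval is needed so the starting point satisfies the constraint.
    mu = pm.Normal(
        "mu",
        mu=0,
        sigma=10,
        shape=K,
        transform=pm.distributions.transforms.ordered,
        initval=np.linspace(-1, 1, K),
    )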

Also, to get a feeling for whether the problem is your data being incompatible with the model, you can generate data using sample_prior_predictive, fit the model to that generated data, and see if it runs faster. See:
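Something along these lines (a sketch; build_model here is a hypothetical helper that rebuilds your mixture model for a given observed array, it is not from your post):

import pymc as pm

# Hypothetical helper: build_model(obs) returns the mixture model with `obs` as observed data.
with build_model(observed_y):
    prior = pm.sample_prior_predictive(500)

# Take one simulated dataset for the observed variable "y" (chain 0, draw 0).
fake_y = prior.prior_predictive["y"].values[0, 0]

# Refit on data the model itself generated: if sampling is still slow and
# poorly mixed here, the problem is the model/parameterization, not your data.
with build_model(fake_y):
    idata_fake = pm.sample()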