Get maximum likelihood of a variable across chains

Is there a way to get the most likely value of a parameter across chains and draws? My model appears to be multimodal, and I am interested in extracting the best value found while MCMC sampling my Mixture model.

Maximum a posteriori (MAP) estimation doesn’t sound like a solution for detecting multimodality either.

Nothing immediate jumps to mind, usually people just look at the marginals or pair plots and try to see if there’s multimodality.

You could perhaps fit some nonparametric mixture model to the posterior draws.
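A simpler nonparametric stand-in than a full mixture model is a Gaussian kernel density estimate (KDE) of the draws, which is just an equal-weight mixture with one Gaussian per draw. Its local maxima give the modes, and the first one returned is the highest peak. This is a NumPy-only sketch; the function name `kde_modes` and the rule-of-thumb bandwidth are assumptions, and the bandwidth choice strongly affects how many peaks you see.

```python
import numpy as np


def kde_modes(draws, grid_size=512, bandwidth=None):
    """Return local maxima of a Gaussian KDE of 1-D posterior draws,
    sorted from highest to lowest estimated density."""
    draws = np.asarray(draws, dtype=float).ravel()  # flatten (chain, draw)
    n = draws.size
    if bandwidth is None:
        # Normal-reference rule of thumb; may oversmooth closely spaced modes
        bandwidth = 1.06 * draws.std(ddof=1) * n ** (-1 / 5)
    grid = np.linspace(draws.min(), draws.max(), grid_size)
    # Sum of Gaussian kernels centred on every draw, evaluated on the grid
    z = (grid[:, None] - draws[None, :]) / bandwidth
    density = np.exp(-0.5 * z**2).sum(axis=1) / (n * bandwidth * np.sqrt(2 * np.pi))
    # Interior grid points that are higher than both neighbours are peaks
    is_peak = (density[1:-1] > density[:-2]) & (density[1:-1] > density[2:])
    idx = np.where(is_peak)[0] + 1
    # Sort peaks by estimated density, highest first
    order = np.argsort(density[idx])[::-1]
    return grid[idx][order], density[idx][order]
```

With PyMC/ArviZ output you would pass something like `idata.posterior["shift"].values`, where `"shift"` is a hypothetical variable name. The memory use grows as `grid_size * n_draws`, so very long traces may need thinning first.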

What about using SMC instead of NUTS? The NUTS trace appears to get stuck in a mode, and the sampler throws max_treedepth warnings. I have read some of the other threads about max tree depth and have already made sure all my parameters are Normals and HalfNormals with realistic values.

If NUTS is having trouble, SMC is unlikely to help. When NUTS struggles, it suggests that the geometry of your posterior is difficult in a sampling-algorithm-agnostic way. SMC might complain less, but it is likely to run into slightly different kinds of issues for the same underlying reasons.

I think @ajc is suggesting SMC because the OP mentioned multimodality. SMC can be better equipped to handle multimodality than NUTS. But SMC can struggle in high dimensions or with complex geometries like the ones induced by hierarchical models, at least the SMC version implemented in PyMC, because it does not use gradients.
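To illustrate why a population-based sampler copes better with multimodality, here is a minimal tempered SMC sketch in NumPy. It is not PyMC's implementation (`pm.sample_smc` chooses its tempering schedule adaptively and uses its own mutation kernels); the prior, the bimodal likelihood, and all tuning values below are hypothetical. Particles start from the prior, where the two modes are not yet separated, and the likelihood is switched on gradually, so particles are shared out between the modes before a barrier forms between them.

```python
import numpy as np


def log_prior(x):
    # Hypothetical broad Normal(0, 5) prior (up to a constant)
    return -0.5 * (x / 5.0) ** 2


def log_lik(x):
    # Hypothetical bimodal likelihood: equal mix of N(-3, 0.5) and N(3, 0.5)
    return np.logaddexp(-0.5 * ((x + 3) / 0.5) ** 2, -0.5 * ((x - 3) / 0.5) ** 2)


def tempered_smc(n_particles=2000, n_stages=20, n_mh=10, step=0.5, seed=0):
    rng = np.random.default_rng(seed)
    x = rng.normal(0.0, 5.0, n_particles)  # stage 0: draws from the prior
    betas = np.linspace(0.0, 1.0, n_stages + 1)  # fixed linear tempering schedule
    for b_prev, b in zip(betas[:-1], betas[1:]):
        # Reweight: incremental importance weight is likelihood^(b - b_prev)
        logw = (b - b_prev) * log_lik(x)
        w = np.exp(logw - logw.max())
        w /= w.sum()
        # Resample particles in proportion to their weights (multinomial)
        x = x[rng.choice(n_particles, size=n_particles, p=w)]
        # Mutate: random-walk Metropolis targeting prior * likelihood^b
        for _ in range(n_mh):
            prop = x + step * rng.normal(size=n_particles)
            log_acc = (log_prior(prop) + b * log_lik(prop)) - (log_prior(x) + b * log_lik(x))
            accept = np.log(rng.uniform(size=n_particles)) < log_acc
            x = np.where(accept, prop, x)
    return x
```

The mutation step is where gradients could enter: replacing the random-walk proposal with HMC or NUTS moves gives better local exploration, but the between-mode allocation still comes from the tempering and resampling.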

@aloctavodia Exactly, although I am unsure if I am using the term multimodality correctly. NUTS previously would stall often, and the trace would look almost like it had delta functions. I could improve it by increasing “max tree depth”, but it would have to go above 20 (prohibitively long runtime). Attached is a trace plot from the SMC sampler for one of my variables:
So I guess I have two questions:

  1. Does this seem like a reasonable use case for SMC? ESS and Rhat both seem reasonable to me and the posterior trace seems to match up very well with the data. So I am assuming yes!
  2. On the trace above, the “mean” value from az.summary is 5.216, which doesn’t seem to correspond to any of the maximum “peaks” in the trace. So, is this “mean” value the one I should use as the optimum parameter value, or should I be looking for the value that corresponds to the highest peak in the trace (and how would I extract that)?

I’m not sure you can trust ESS and Rhat metrics from SMC samples?

pymc-experimental has a function for sampling with the blackjax SMC sampler, which does use gradients (it can run either HMC or NUTS in the inner loop), but it’s essentially undocumented. Help wanted!

Is there a benefit to doing SMC with gradients?

In principle it should give you better local exploration during the population mutation step. But it depends on the problem specifics.

How many parameters does your model have?

I believe 5 (although I’m not sure whether the Dirichlet weights actually count as 1 parameter or 2).
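On the Dirichlet question: K mixture weights are constrained to sum to one, so they carry K - 1 free parameters; with two components that is a single free parameter. The NumPy sketch below shows this with an additive log-ratio map from K weights to K - 1 unconstrained values and back. It is a generic illustration, not necessarily the exact simplex transform PyMC uses internally.

```python
import numpy as np


def alr_forward(w):
    # Map K simplex weights to K - 1 unconstrained reals (log ratios vs. last weight)
    w = np.asarray(w, dtype=float)
    return np.log(w[:-1]) - np.log(w[-1])


def alr_inverse(z):
    # Append 0 for the reference component, then softmax back onto the simplex
    z = np.append(np.asarray(z, dtype=float), 0.0)
    e = np.exp(z - z.max())
    return e / e.sum()
```

For two weights such as `[0.7, 0.3]`, `alr_forward` returns a single number, and `alr_inverse` recovers both weights from it.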

Yes this should be a good candidate for SMC.

To actually answer your original question, you can use pm.compute_log_likelihood to compute the log-likelihood of the observed data at each posterior draw. You could then rank the draws by their log-likelihoods.
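The ranking step can be sketched in NumPy. This assumes you already have pointwise log-likelihoods with shape (chain, draw, observations); after `pm.compute_log_likelihood(idata)` (inside the model context) these live in `idata.log_likelihood` under the observed variable's name, so something like `idata.log_likelihood["y"].values`, where `"y"` is a hypothetical name. The helper `rank_draws_by_loglik` is also hypothetical. Note this ranks by likelihood only; ranking by posterior density would also need each draw's log-prior.

```python
import numpy as np


def rank_draws_by_loglik(pointwise_loglik):
    """Rank (chain, draw) pairs by total log-likelihood, best first.

    pointwise_loglik: array of shape (chain, draw, ...) where the trailing
    axes index observations.
    """
    ll = np.asarray(pointwise_loglik, dtype=float)
    n_chain, n_draw = ll.shape[:2]
    # Sum over all observation axes -> one total log-likelihood per draw
    total = ll.reshape(n_chain, n_draw, -1).sum(axis=-1)
    # Flattened indices sorted from highest to lowest total log-likelihood
    order = np.argsort(total, axis=None)[::-1]
    chains, draws = np.unravel_index(order, total.shape)
    return chains, draws, total[chains, draws]
```

The best draw's parameter values could then be pulled out with `idata.posterior.isel(chain=chains[0], draw=draws[0])`.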

I’m not totally sure why you would want to do this; the whole point of Bayes is to get the entire posterior. Although, as you correctly point out, the mean isn’t representative of any single mode, it is representative of all the modes, weighted by their respective probabilities. If you’re only interested in the mode, you can save a lot of compute by just using pm.find_MAP.
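A quick NumPy illustration of that point, using hypothetical draws from two well-separated modes with 70/30 weights: the sample mean lands near the weighted average of the mode locations, 0.7 * (-3) + 0.3 * 3 = -1.2, which sits in the low-density gap between the peaks.

```python
import numpy as np

rng = np.random.default_rng(1)
# Hypothetical posterior draws: 70% near -3, 30% near +3
draws = np.concatenate([rng.normal(-3.0, 0.5, 7000), rng.normal(3.0, 0.5, 3000)])

mean = draws.mean()
# Weighted average of the two mode locations
expected = 0.7 * -3.0 + 0.3 * 3.0
# Fraction of draws within 0.5 of the mean: small, because the mean lies between modes
near_mean = np.mean(np.abs(draws - mean) < 0.5)
print(f"mean={mean:.2f}, expected~{expected:.2f}, fraction near mean={near_mean:.3f}")
```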

Do you have a scientific reason to expect such a strongly multi-modal posterior? It looks highly unusual.

The underlying data set is biological data with a heavy right tail. The tail population also behaves differently, so we fit a Gaussian to the main population and a shifted exponential to the tail. My understanding from reading is that a Gaussian plus shifted exponential mixture model admits multiple solutions (especially for the shift and lambda parameters) that fit the data well.
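For reference, the mixture described here has density w * Normal(x | mu, sigma) + (1 - w) * lam * exp(-lam * (x - shift)), where the exponential term applies only for x >= shift. Below is a NumPy sketch of its log-density, computed stably with `np.logaddexp`; the function and parameter names are hypothetical, not taken from the original model. Evaluating the summed log-likelihood over a grid of `shift` and `lam` values is one way to look at the multiple well-fitting solutions directly.

```python
import numpy as np


def gauss_shiftexp_logpdf(x, w, mu, sigma, lam, shift):
    """Log-density of w * Normal(mu, sigma) + (1 - w) * ShiftedExponential(lam, shift)."""
    x = np.asarray(x, dtype=float)
    log_norm = -0.5 * ((x - mu) / sigma) ** 2 - np.log(sigma * np.sqrt(2 * np.pi))
    # Shifted exponential: lam * exp(-lam * (x - shift)) for x >= shift, else zero density
    log_exp = np.where(x >= shift, np.log(lam) - lam * (x - shift), -np.inf)
    # log(w * Normal + (1 - w) * Exp) without under/overflow
    return np.logaddexp(np.log(w) + log_norm, np.log1p(-w) + log_exp)
```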

You should not trust ESS from SMC. By design, ESS from SMC will always be high, unless your results are so bad that a trace plot or rank plot would already make it evident.
R-hat is informative for SMC, and so are rank plots.
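To see why R-hat still helps with multimodality, here is a minimal NumPy sketch of the classic split R-hat (Gelman-Rubin with each chain split in half). ArviZ's `az.rhat` defaults to a rank-normalized variant, so its numbers will differ somewhat, and the chains below are synthetic. When chains sit in different modes, the between-chain variance dominates and R-hat rises well above 1.

```python
import numpy as np


def split_rhat(x):
    """Classic split R-hat for an array of shape (chain, draw)."""
    x = np.asarray(x, dtype=float)
    half = x.shape[1] // 2
    # Split every chain into two halves -> 2 * n_chain sub-chains of length `half`
    sub = np.concatenate([x[:, :half], x[:, half:2 * half]], axis=0)
    n = sub.shape[1]
    chain_means = sub.mean(axis=1)
    within = sub.var(axis=1, ddof=1).mean()         # W: mean within-chain variance
    between = n * chain_means.var(ddof=1)           # B: between-chain variance
    var_hat = (n - 1) / n * within + between / n    # pooled variance estimate
    return np.sqrt(var_hat / within)
```

For example, four well-mixed chains give a value very close to 1, while the same chains shifted so that two sit near -3 and two near +3 give a value far above 1.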