Prior Predictive Sampling in a Multilevel Linear Model

I’ve only had a cursory think about what you’ve presented so I’ll chip in some bits that popped up to me while reading (apologies in advance for the long reply, but you started it :joy:).

I'd say this is to be expected. Scaling up a model can be difficult and can lead to poorer convergence and poorer sampling. One thing that may be worth looking into is the identifiability of your problem, which effectively means how well your model pins down the data generation process, i.e. whether there are multiple parameter settings that give the same data. HMC (I'm assuming you are using NUTS for your sampling) breaks down very loudly, which is part of its diagnostic benefit, and my preliminary intuition is that scaling to include more cross sections introduces more ways to get the same data. If this is to blame for your problem (I have the same problem with my work at the moment :sweat_smile:) then regularisation helps (like the stricter priors you mentioned trying). You can also encode more structure in the model to reduce the room for degeneracy, or find more information if any is available (usually the right kind of information, not just more of the same data).
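To make the identifiability idea concrete, here is a minimal sketch using a made-up toy model (not your model): `y ~ Normal(a + b, 1)`. The names `a`, `b` and the data values are purely hypothetical. Only the sum `a + b` enters the likelihood, so very different `(a, b)` pairs explain the data equally well, and a sampler has a whole ridge to wander along unless the priors break the tie.

```python
import math

def gaussian_loglik(y, mean, sigma=1.0):
    """Log-likelihood of the observations y under Normal(mean, sigma)."""
    return sum(
        -0.5 * math.log(2 * math.pi * sigma**2) - (obs - mean) ** 2 / (2 * sigma**2)
        for obs in y
    )

def loglik(y, a, b):
    # Hypothetical toy model: y ~ Normal(a + b, 1).
    # Only the sum a + b appears, so a and b are not separately identifiable.
    return gaussian_loglik(y, a + b)

y = [1.8, 2.1, 2.4, 1.9]  # made-up observations

ll_1 = loglik(y, a=2.0, b=0.0)
ll_2 = loglik(y, a=-3.0, b=5.0)  # very different parameters, same sum

print(ll_1, ll_2)  # the two values are identical
```

Tight priors on `a` or `b`, or a constraint such as fixing one of them, are the kind of extra structure that removes the ridge.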

The paragraphs that follow, where you mention exploring redundant regions, basically fit this identifiability picture, so in case you didn't know the term, perhaps that helps.

It is true that non centred parameterisations can be nice, but unfortunately it isn't necessarily that simple. I am not 100% on the nitty gritty details, but here are some pymc posts worth consulting that helped my understanding of why non centred isn't a clear cut advantage: "Noncentered Parametrizations" and "Non-centering results in lower effective sample size". I also found more by just typing "pymc non centred parameterisation" into Google, so there is plenty of discussion to be found. I haven't thought too much about whether changing parameterisation would necessarily help, but it is worth a mention.
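As a rough illustration of what non-centring changes, here is a small sketch in plain Python (no pymc needed). It assumes a made-up prior `tau ~ LogNormal(0, 1)` and a group effect `theta ~ Normal(0, tau)`, which are my choices for illustration only. In the centred form the sampler works with `theta`, whose spread depends on `tau` (the funnel); in the non-centred form it works with `z ~ Normal(0, 1)` and rebuilds `theta = tau * z`, so `z` and `tau` are independent a priori.

```python
import math
import random

def pearson(xs, ys):
    """Plain Pearson correlation between two equal-length lists."""
    n = len(xs)
    mx, my = sum(xs) / n, sum(ys) / n
    cov = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
    vx = sum((x - mx) ** 2 for x in xs)
    vy = sum((y - my) ** 2 for y in ys)
    return cov / math.sqrt(vx * vy)

rng = random.Random(42)
n = 20000

# Hypothetical prior: log(tau) ~ Normal(0, 1), i.e. tau ~ LogNormal(0, 1).
log_tau = [rng.gauss(0.0, 1.0) for _ in range(n)]
# Non-centred variable: standard normal, drawn independently of tau.
z = [rng.gauss(0.0, 1.0) for _ in range(n)]
# Centred variable rebuilt from it: theta = tau * z, so theta ~ Normal(0, tau).
theta = [math.exp(lt) * zi for lt, zi in zip(log_tau, z)]

# How strongly does the scale of each variable depend on tau?
corr_centred = pearson(log_tau, [math.log(abs(t)) for t in theta])
corr_non_centred = pearson(log_tau, [math.log(abs(zi)) for zi in z])

print(f"centred:     {corr_centred:.2f}")      # clearly positive (funnel)
print(f"non-centred: {corr_non_centred:.2f}")  # close to zero
```

This only shows the prior geometry. With a lot of data per group the posterior can look quite different, which is one reason non-centring is not automatically the better choice.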

As for reducing the model to a single cross section and still getting problems, I'll admit I'm not immediately sure how to diagnose that, nor the additional point about being unable to sample the prior predictive on the whole model, but as you mention, you seem to at least have an initial lead on solving that.

I guess to summarise: (1) yes, it is to be expected (if you are mainly talking about convergence, sampling time, etc.). To quote Thomas Wiecki, “Hierarchical models are awesome, tricky and Bayesian” (from his blog, while my_mcmc: gently(samples), post "Why hierarchical models are awesome, tricky, and Bayesian"). The tricky part of that quote is the real pain, since there are a few pain points involved, like funnel-shaped distributions and even identifiability problems.

(2) I would say have a think about parameterisations (or just change them by trial and error) to see if you can find a benefit. Maybe also think about identifiability, and whether it is possible to alter your model for better results. This would hopefully also help with (4), your runtime troubles, which is another thing to tackle (I elaborate on that below).

(3) The r-hat diagnostic is good. It does have some limitations, but someone more experienced would be better placed to elaborate on those. The only thing I'd mention, if you haven't done it already, is that r-hat is more robust when computed from several shorter chains rather than a single longer one.
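For intuition about why r-hat wants several chains, here is a minimal sketch of the basic Gelman-Rubin version in plain Python. Note this is the textbook formula, not what ArviZ computes (as far as I know ArviZ uses a rank-normalised split variant), and the simulated "chains" are made up. It compares the variance between chain means with the variance within chains, so with a single chain there is nothing to compare.

```python
import math
import random
import statistics

def basic_rhat(chains):
    """Basic Gelman-Rubin R-hat for a list of equal-length chains of one parameter."""
    n = len(chains[0])
    chain_means = [statistics.mean(c) for c in chains]
    # W: average within-chain variance.
    within = statistics.mean([statistics.variance(c) for c in chains])
    # B: between-chain variance, scaled by the chain length.
    between = n * statistics.variance(chain_means)
    # Pooled estimate of the posterior variance.
    var_hat = (n - 1) / n * within + between / n
    return math.sqrt(var_hat / within)

rng = random.Random(1)

# Four made-up chains that all explore the same distribution.
good = [[rng.gauss(0.0, 1.0) for _ in range(500)] for _ in range(4)]
# Four made-up chains stuck in two different places.
bad = [[rng.gauss(shift, 1.0) for _ in range(500)] for shift in (0.0, 0.0, 3.0, 3.0)]

rhat_good = basic_rhat(good)
rhat_bad = basic_rhat(bad)
print(f"agreeing chains:    {rhat_good:.3f}")  # close to 1
print(f"disagreeing chains: {rhat_bad:.3f}")   # well above 1
```

Each of the "bad" chains looks perfectly healthy on its own; it is only the comparison across chains that exposes the problem.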

(4) Again, I'm assuming that you are using the NUTS sampler. It is worth profiling your code to figure out where the pinch points are. pymc does have profiling features; I'm still not 100% on using them, but there have been a few posts recently about debugging and profiling a model. My initial intuition is that your model problems are the main source of your grief, resulting in a large NUTS tree depth, meaning the sampler spends a lot of time on each Hamiltonian trajectory. You can check whether this is the case by indexing the tree_depth entry of the sample_stats group of your trace, so something like:

print(trace.sample_stats['tree_depth'])

which would print an array of numbers. Say the result is an array of 10s: since NUTS builds its trajectory by repeatedly doubling a binary tree, that means on the order of 2^10 gradient evaluations per draw, which naturally becomes quite expensive. More recent versions of pymc give you a warning about hitting the maximum tree depth and advise increasing it or reparameterising, but I don't know which version of pymc you have, so I don't know if you would be getting this warning or not :sweat_smile:.
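If the raw array is too long to eyeball, something like the sketch below can summarise it. The helper `summarise_tree_depth` is a hypothetical name of my own, not a pymc or ArviZ function, and it assumes the maximum tree depth is 10 (the pymc default, as far as I know). The step count uses the rough 2^depth figure from above rather than the sampler's exact bookkeeping.

```python
from collections import Counter

def summarise_tree_depth(depths, max_treedepth=10):
    """Summarise NUTS tree depths: counts, share at the cap, rough cost per draw."""
    depths = [int(d) for d in depths]
    return {
        # How many draws ended at each depth.
        "counts": dict(sorted(Counter(depths).items())),
        # Fraction of draws that hit the assumed maximum depth.
        "frac_at_max": sum(d >= max_treedepth for d in depths) / len(depths),
        # Rough average number of gradient evaluations per draw (about 2**depth).
        "mean_steps": sum(2 ** d for d in depths) / len(depths),
    }

# Made-up depths for illustration. With a real trace you could pass e.g.
# trace.sample_stats["tree_depth"].values.ravel()
example = [3, 4, 10, 10]
summary = summarise_tree_depth(example)
print(summary)
```

A large `frac_at_max` would support the suspicion that the geometry of the model, rather than the code itself, is what makes sampling slow.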

Hope any of this helps, even if it is just aiding understanding of what the problems may be!