Prior Predictive Sampling in a Multilevel Linear Model

Hi all,

Please excuse the long post.

I am currently working on a multilevel linear model for marketing mix prediction with four levels, each level containing between 2 and 7 groups. For example, the State level has 7 groups (South Australia, Western Australia, …), the Brand level has 2 groups (Brand X, Brand Y), and so on.

I have modelled the design by extending the two example notebooks linked here and here.

All of my group-level variables use a non-centered parameterisation, as is commonly suggested as best practice.

The X data are scaled by their maximum values within each cross section, and the target values are shifted by the minimum and scaled by the range within each cross section.
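
For concreteness, the scaling is done roughly like this (a sketch; `df`, `channel_cols` and the `"target"` column are placeholders for my actual data):

    import pandas as pd

    # Long-format frame with one row per date and cross section
    group_keys = ["state", "age", "brand", "cohort"]
    grouped = df.groupby(group_keys)

    # X: divide each media/control column by its maximum within the cross section
    X_scaled = grouped[channel_cols].transform(lambda s: s / s.max())

    # y: shift by the minimum and scale by the range within the cross section
    y_scaled = grouped["target"].transform(lambda s: (s - s.min()) / (s.max() - s.min()))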

The coordinates of the model are:

coords = {
    "date": dates,
    "channel": channels,
    "pos_control": pos_controls,
    "neg_control": neg_controls,
    "lag": lags,
    "state": mn_state,
    "age": mn_age,
    "brand": mn_brand,
    "cohort": mn_cohort,
}

where

date: length 156
channel: length 23
pos_control: length 19
neg_control: length 4
lag: length 1
state: length 7
age: length 6
brand: length 2
cohort: length 3

The model is essentially:

    # Model error
    sd_y = pm.Exponential("sd_y", 1)

    y_hat = pm.Deterministic(
        "y_hat",
        var=alpha[np.newaxis, :, :, :, :]
        + channel_contribution.sum(axis=1)
        + pos_control_contribution.sum(axis=1)
        + neg_control_contribution.sum(axis=1)
        + lag_contribution.sum(axis=1),
        dims=("date", "state", "age", "brand", "cohort"),
    )

    # Data likelihood
    y_like = pm.Normal(
        "y_like",
        mu=y_hat,
        sigma=sd_y[np.newaxis, ...],
        observed=target_value,
        dims=("date", "state", "age", "brand", "cohort"),
    )

I will define channel_contribution below, because the other components of mu are analogous.

    # Slopes
    mu_b = pm.HalfNormal("mu_b", sigma=1, dims=("channel"))

    # Non-centered random slopes - state level
    z_b_state = pm.HalfNormal("z_b_state", sigma=1, dims=("channel", "state"))
    sigma_b_state = pm.Exponential("sigma_b_state", 1)

    # Non-centered random slopes - age level
    z_b_age = pm.HalfNormal("z_b_age", sigma=1, dims=("channel", "age"))
    sigma_b_age = pm.Exponential("sigma_b_age", 1)

    # Non-centered random slopes - brand level
    z_b_brand = pm.HalfNormal("z_b_brand", sigma=1, dims=("channel", "brand"))
    sigma_b_brand = pm.Exponential("sigma_b_brand", 1)

    # Non-centered random slopes - cohort level
    z_b_cohort = pm.HalfNormal("z_b_cohort", sigma=1, dims=("channel", "cohort"))
    sigma_b_cohort = pm.Exponential("sigma_b_cohort", 1)

    betas = pm.Deterministic(
        "betas",
        mu_b[:, np.newaxis, np.newaxis, np.newaxis, np.newaxis]
        + z_b_state[:, :, np.newaxis, np.newaxis, np.newaxis] * sigma_b_state
        + z_b_age[:, np.newaxis, :, np.newaxis, np.newaxis] * sigma_b_age
        + z_b_brand[:, np.newaxis, np.newaxis, :, np.newaxis] * sigma_b_brand
        + z_b_cohort[:, np.newaxis, np.newaxis, np.newaxis, :] * sigma_b_cohort,
        dims=("channel", "state", "age", "brand", "cohort"),
    )

    channel_contribution = pm.Deterministic(
        "channel_contribution",
        var=betas * _saturated,
        dims=("date", "channel", "state", "age", "brand", "cohort"),
    )

Clearly, channel_contribution is non-negative (as it needs to be for media contributions). Further, _saturated is the media data with adstock and saturation transforms (taken directly from pymc-marketing) applied to it.
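
For context, `_saturated` is built roughly along these lines, sketched here for a single cross section (a (date, channel) array); the transform priors and `channel_data` are placeholders, and `geometric_adstock` / `logistic_saturation` are the pymc-marketing transformers:

    from pymc_marketing.mmm.transformers import geometric_adstock, logistic_saturation

    # Placeholder priors for the transform parameters, one per channel
    adstock_alpha = pm.Beta("adstock_alpha", alpha=1, beta=3, dims="channel")
    saturation_lam = pm.Gamma("saturation_lam", alpha=3, beta=1, dims="channel")

    # Carry-over along the date axis, then diminishing returns
    _adstocked = geometric_adstock(channel_data, alpha=adstock_alpha, l_max=8, axis=0)
    _saturated = logistic_saturation(_adstocked, lam=saturation_lam)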

alpha, pos_control_contribution, and lag_contribution are all defined in a similar way (all weakly positive), and unsurprisingly, neg_control_contribution is weakly negative.

So now I get to the crux of it.

When I run this model on various specific cross sections of the data with a target accept of 80%, I usually get 1-2 divergences, all of my r_hat values are 1, and sampling takes about a minute. This seems good.

Then, increasing the cross sections included (expanding across the levels, for example including all ages, then all brands, then all cohorts, then all states, in no particular order) seems to work OK up to a point (I can't say exactly where this critical mass is), after which r_hat issues start to appear, though still with very few divergences. The model samples very slowly (36 hours for the full model), and the r_hat values become very, very poor, with 90% of the estimated variables showing r_hat > 1.01.

So my thinking here was that the sampler must be exploring redundant regions that are not close to reality before getting to where it needs to go, and that is why it isn’t confident in the estimates. To check this, I run pm.sample_prior_predictive on a specific cross section, and lo and behold I get a plot that looks like this:

Clearly the prior predictions lie above the target value. I think this must be because of the positivity constraints I have on my variables and the priors on those being too wide? Interestingly, the prior predictive samples actually match the data pretty well if I scale the target variable up. Despite this, as I said before, when I run the model on this specific cross section, I get no divergences and good r_hats, and I am happy with the predictions for a first pass.
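
For reference, the check itself is just something like this (`model` is my model context, and `target_value` is the (date, state, age, brand, cohort) array used as observed):

    import matplotlib.pyplot as plt

    with model:
        prior_idata = pm.sample_prior_predictive(500, random_seed=1)

    # Overlay 100 prior draws of y_like on the observed target for one cross section
    prior_y = prior_idata.prior_predictive["y_like"].stack(sample=("chain", "draw"))
    plt.plot(
        prior_y.isel(state=0, age=0, brand=0, cohort=0, sample=slice(0, 100)).values,
        color="C0", alpha=0.1,
    )
    plt.plot(target_value[:, 0, 0, 0, 0], color="k")
    plt.show()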

So I try to tune the priors down a bit, so the sampler can begin in a more reasonable region, and I do this by reducing the sigma values for the half-normal distributions in the model to 0.01 (I admit this is quite lazy). I get to a point which seems to cover the target value quite well:

But then actually sampling even the single cross section model produces r_hat issues (3% of variables have r_hat > 1.01). Expanding across a hierarchy (state, for example) produces poorer r_hat values than with the wider priors.

(I am unable to run pm.sample_prior_predictive on the full model without getting an error of ValueError: conflicting sizes for dimension 'state': length 1 on the data but length 7 on coordinate 'state', which seems to be similar to the issue faced in this post, but that isn’t the point of my question.)

So I would like to know:

  1. Is this behaviour expected?
  2. What should I try next? (Scaling the y data up doesn’t seem to help)
  3. Are there some gaps in my understanding about the r_hat reasons?
  4. Are there any suggestions for speeding up the sampler?

Thank you for reading all of that. I would be super grateful for any pointers.

I’ve only had a cursory think about what you’ve presented so I’ll chip in some bits that popped up to me while reading (apologies in advance for the long reply, but you started it :joy:).

This is to be expected, I would say. Scaling up a model can be difficult and lead to poorer convergence and slower sampling. Part of this that may be worth looking into is evaluating the identifiability of your problem, effectively meaning how well your model pins down the data generating process, i.e. whether there are multiple ways to produce the same data. HMC (I'm assuming you are using NUTS for your sampling) breaks down very loudly, which is part of the diagnostic benefit of HMC, and my preliminary intuition is that scaling to include more cross sections means more ways to get the same data. If this is to blame for your problem (I have the same problem in my own work at the moment :sweat_smile:), then regularisation helps (like the stricter priors you mentioned), you can also encode more structure into your approach to reduce the scope for degeneracy, or find more information if there is any available (usually the right kind of information, not just more of the same data).

The paragraphs where you mention exploring redundant regions basically fit this identifiability picture, but in case you didn't know the term, perhaps that helps.

It is true that non-centred parameterisations can be nice, but it isn't necessarily that simple, unfortunately. I am not 100% on the nitty-gritty details, but here are some pymc posts worth consulting that helped my understanding of why non-centred isn't a clear-cut advantage: Noncentered Parametrizations, Non-centering results in lower effective sample size. I also found more discussion by just typing "pymc non centred parameterisation" into Google, so there is plenty to be found. I haven't thought too much about whether changing parameterisation would necessarily help, but it is worth a mention.
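
For what it's worth, the contrast between the two parameterisations looks something like this for one of your group effects (a generic sketch with Normal offsets; your model uses HalfNormal offsets to keep contributions positive, and in practice you would pick one of the two forms, not both):

    import pymc as pm

    with pm.Model(coords={"channel": channels, "state": mn_state}) as demo:
        sigma_b_state = pm.Exponential("sigma_b_state", 1)

        # Centered: group slopes drawn directly from their population distribution
        b_state_centered = pm.Normal(
            "b_state_centered", mu=0, sigma=sigma_b_state, dims=("channel", "state")
        )

        # Non-centered: standard-normal offsets scaled by the population sigma
        z_b_state = pm.Normal("z_b_state", mu=0, sigma=1, dims=("channel", "state"))
        b_state_noncentered = pm.Deterministic(
            "b_state_noncentered", z_b_state * sigma_b_state, dims=("channel", "state")
        )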

The bit you mention about reducing down to a single cross section and still getting problems, I'll admit I am not immediately sure how to diagnose, nor the additional point about being unable to sample the prior predictive on the whole model, but as you mention you seem to at least have an initial lead on solving that.

I guess to summarise: (1) yes, it is to be expected (if you are mainly talking about convergence, sampling time, etc.). To quote Thomas Wiecki, "Hierarchical models are awesome, tricky and Bayesian" (while my_mcmc: gently(samples) - Why hierarchical models are awesome, tricky, and Bayesian), and the tricky part of the quote is the real pain, since there are a few pain points involved, like funnel-shaped posteriors and even identifiability problems.

(2) I would say have a think about parameterisations (or just trial-and-error changes of parameterisation) to see if you can find a benefit. Maybe also think about identifiability, and whether it is possible to alter your model for better results. These would hopefully also help with (4), your runtime troubles, which are another thing to tackle (I will elaborate two paragraphs down).

(3) The r-hat diagnostic is good; it does have some limitations, but someone more experienced would be better placed to elaborate on them. The only thing I'd mention, if you haven't done it already, is that r-hat is more robust when computed over more, shorter chains rather than a single longer one.
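
In practice that just means something like this (chain and draw counts here are illustrative), then checking the diagnostics with arviz:

    import arviz as az

    # More, shorter chains give split R-hat more to compare across
    with model:
        trace = pm.sample(draws=500, tune=1000, chains=8)

    # Rank-normalised R-hat and effective sample sizes per variable
    print(az.summary(trace, kind="diagnostics"))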

(4) Again, I'm assuming that you are using the NUTS sampler. It is worth profiling your code to figure out where the pinch points are; pymc does have profiling features, though I'm still not 100% on using them, but there have been a few posts recently about debugging and profiling a model. My initial intuition is that your model problems are the main source of your grief, resulting in large NUTS tree depths, meaning it is spending a lot of time on each Hamiltonian trajectory. You can check whether this is the case by indexing the tree_depth part of the sample_stats in your trace, so something like:

print(trace.sample_stats['tree_depth'])

which returns an array of numbers. Say the result is an array of 10s: since NUTS uses binary tree doubling, that means roughly 2^10 gradient evaluations per draw, which naturally becomes quite burdensome. More recent versions of pymc give you a warning about hitting the maximum tree depth and advise increasing it or reparameterising, but I don't know which version of pymc you have, so I don't know whether you would be getting this warning or not :sweat_smile:.
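
To make that easier to read, something like this gives the share of draws at each depth (assuming the stat is stored under `tree_depth`, as in recent pymc versions):

    # Proportion of draws at each tree depth; lots of 10s means roughly
    # 2^10 gradient evaluations per draw, i.e. very expensive trajectories
    depths = trace.sample_stats["tree_depth"]
    print(depths.to_series().value_counts(normalize=True).sort_index())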

Hope any of this helps, even if it is just aiding understanding of what the problems may be!

Thanks for all of that @brandonhorsley, super insightful.

I have been using the numpyro NUTS implementation instead of the default PyMC sampler because it seems to be faster.

I have tightened up the priors on my model and also reparametrised the exponential scale parameters as half-normals. This has cut the sampling time down to 8 hours, which is a huge improvement, and the r_hat values are looking better. Not where I want them yet, but much better. I think my next step is to pass more specific priors, rather than a general prior for each variable; I know some variables have a much stronger effect than others.
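
Roughly, the change was along these lines, with channel-specific scales as the next step (the values are illustrative, and `strong_channels` is a hypothetical list of channels I expect to matter more):

    # Group-level scales switched from Exponential(1) to HalfNormal
    sigma_b_state = pm.HalfNormal("sigma_b_state", sigma=0.5)

    # Next step: channel-specific prior scales instead of one shared value
    channel_prior_sigma = np.array(
        [0.5 if c in strong_channels else 0.1 for c in channels]
    )
    mu_b = pm.HalfNormal("mu_b", sigma=channel_prior_sigma, dims="channel")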

Perhaps I can run the default PyMC NUTS sampler on a cross section and see where the pinch points are, as you say.

Thanks again.
