Feasibility of large hierarchical multiple regression

I’m fairly new to Bayesian modeling, but I’m hoping to incorporate more of it into my data analysis routine. I have a large dataset from an experimental neuroscience study with multiple subjects and time-series data.

For this specific example, I have a 34081 x 32 design matrix containing samples from 8 subjects. I was hoping to fit this as a hierarchical model with random slopes and intercepts. My model is:

import pandas as pd
import pymc as pm

# x_copy: 34081 x 32 design matrix (DataFrame); y: response vector;
# mice: subject labels; group_idx: per-observation subject index
coords = {
    'Mice': pd.Series(mice),
    'obs_id': x_copy.index.to_numpy(),
    'predictors': pd.Series(x_copy.columns),
}

with pm.Model(coords=coords) as model:
    data = pm.MutableData('x', x_copy, dims=("obs_id", "predictors"))
    g = pm.MutableData("group", group_idx, dims="obs_id")

    # Hyperpriors
    mu_a = pm.Normal('mu_a', mu=0., sigma=10)
    sigma_a = pm.HalfNormal('sigma_a', 5.)
    mu_b = pm.Normal('mu_b', mu=0., sigma=10)
    sigma_b = pm.HalfNormal('sigma_b', 5.)
    sigma_hyper = pm.HalfNormal('sigma_hyper', 5.)

    # Non-centered Reparameterization for group-level parameters
    z_betas = pm.Normal('z_betas', mu=0., sigma=1., dims=("Mice", "predictors"))
    z_intercept = pm.Normal('z_intercept', mu=0., sigma=1., dims="Mice")
    betas = pm.Deterministic('betas', mu_a + z_betas * sigma_a, dims=("Mice", "predictors"))
    intercept = pm.Deterministic('intercept', mu_b + z_intercept * sigma_b, dims="Mice")

    # Model error (note: dims="obs_id" gives one sigma per observation)
    sigma = pm.HalfNormal('sigma', sigma=sigma_hyper, dims="obs_id")

    # Define likelihood
    mu = pm.Deterministic('mu', intercept[g] + (data * betas[g]).sum(axis=-1))
    likelihood = pm.Normal('y', mu=mu, sigma=sigma, observed=y, dims="obs_id")

I used a non-centered reparameterization and ADVI to initialize the model before NUTS sampling, but because of the complexity of the model, it is taking an extremely long time to draw enough samples. So I was hoping I could get some input from more seasoned Bayesians. Do I abandon some things I want in my model? Do I put this on a supercomputing cluster and block off a week to sample?
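
For reference, here is roughly the sampling call I mean (the tune/draw counts below are placeholders rather than my exact settings):

with model:
    idata = pm.sample(
        init="advi+adapt_diag",  # ADVI initialization before NUTS
        tune=1000,               # placeholder values
        draws=1000,
        target_accept=0.9,
    )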

All help is appreciated. I apologize if anything looks weird. Like I said, I’m new to this.

I assume your convergence isn’t great. Do you get divergences?
With many data points, the centered parameterization often works better. Also, play around with your priors a bit; they look like they might be too wide.
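
For example, a minimal sketch of the centered version of your group-level terms (replacing the z_* variables and the Deterministic lines) would look something like:

    betas = pm.Normal('betas', mu=mu_a, sigma=sigma_a, dims=("Mice", "predictors"))
    intercept = pm.Normal('intercept', mu=mu_b, sigma=sigma_b, dims="Mice")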

Once you have good convergence, if it's still too slow you can experiment with other samplers: Faster Sampling with JAX and Numba — PyMC example gallery
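
For instance, assuming you're on a recent PyMC version with numpyro (or blackjax/nutpie) installed, switching the NUTS backend is just:

with model:
    idata = pm.sample(nuts_sampler="numpyro")  # or "blackjax" / "nutpie"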

You have a large but not a huge amount of data; we've run models on a lot more data in a few minutes.


Hi Thomas, thanks for the help. Interestingly, I was getting lots of divergences when I was using the centered parameterization, and switching to the non-centered seemed to fix that problem, but it was still awfully slow.

I think my big issue is that it goes from taking a reasonable amount of time (say, 2-3 hours) for the first few thousand samples to something absurd (like 100 hours, and the estimate keeps increasing as sampling goes on). I'm still learning the library, so I'm still figuring out the best way to troubleshoot.

I’ll try using a different sampler per your suggestion and hopefully that will help fix things. Thanks again!

The way you’ve written it, you’re pooling information across predictors, not just across mice. Is that intentional? A priori it doesn’t seem right to me – why would knowing the effect of feature_1 tell me anything about what to expect for the effect of feature_6? On the other hand, knowing the effect of feature_1 on mouse_3 does give me information about what to expect for feature_1 when I look at mouse_12.
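
If the pooling across predictors isn't what you want, a rough sketch (illustrative, not drop-in code) of giving each predictor its own hyperprior, so shrinkage only happens across mice within a predictor, would be:

    mu_a = pm.Normal('mu_a', mu=0., sigma=1., dims="predictors")
    sigma_a = pm.HalfNormal('sigma_a', 1., dims="predictors")
    z_betas = pm.Normal('z_betas', mu=0., sigma=1., dims=("Mice", "predictors"))
    betas = pm.Deterministic('betas', mu_a + z_betas * sigma_a, dims=("Mice", "predictors"))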

Also I would say the sampling time is definitely too long for a model like this. Are you getting any warnings about using numpy BLAS when you import PyMC?
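
One quick way to check (just a diagnostic suggestion) is to look at which BLAS numpy is linked against:

import numpy as np
np.show_config()  # prints the BLAS/LAPACK libraries numpy was built against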

Hmm, no, that was not intentional. Good catch! I'm not getting any warnings either.