Is multicollinearity a problem when fitting a regression model using ADVI?

If I’m fitting a Bayesian regression model using ADVI, is it important to ensure that the covariates are uncorrelated with each other? I have a vague understanding that ADVI doesn’t play well with correlations in the posterior, and my understanding is that multicollinearity should induce correlations between the posterior estimates of the coefficients. I’ve included the graphviz output of my model in case it’s helpful.

Hi elewer,

Sorry to answer so late, I hope this is still of interest to you!

I think you are right to be concerned. Mean-field ADVI cannot capture posterior correlations, and I think that multicollinearity should give rise to correlations in the posterior, as you say.

What should end up happening, at least in a simple model (and yours doesn’t look too bad!), is that ADVI will get the posterior means approximately right, but it will likely underestimate the posterior variance, and it will completely fail to represent the posterior correlation. Taken together, this means that when you predict, you’re likely to get a good mean for the prediction, but a poor variance estimate.
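
If it helps to see this concretely, here’s a minimal sketch on synthetic data (pymc3 API; the data and model are illustrative stand-ins, not yours) comparing the posterior of two nearly collinear coefficients under NUTS and under mean-field ADVI:

```python
import numpy as np
import pymc3 as pm

# Synthetic data with two nearly collinear covariates.
rng = np.random.default_rng(0)
x1 = rng.normal(size=200)
x2 = 0.95 * x1 + 0.05 * rng.normal(size=200)
y = 1.0 * x1 + 1.0 * x2 + rng.normal(scale=0.5, size=200)

with pm.Model() as model:
    beta = pm.Normal("beta", mu=0.0, sigma=1.0, shape=2)
    sigma = pm.HalfNormal("sigma", sigma=1.0)
    pm.Normal("y", mu=beta[0] * x1 + beta[1] * x2, sigma=sigma, observed=y)

    trace_nuts = pm.sample(1000, tune=1000, return_inferencedata=False)
    approx = pm.fit(n=30_000, method="advi")
    trace_advi = approx.sample(1000)

# NUTS should show a strong negative correlation between the two coefficients;
# mean-field ADVI forces that correlation to ~0 and narrows the marginals.
print(np.corrcoef(trace_nuts["beta"].T))
print(np.corrcoef(trace_advi["beta"].T))
```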

If this is a concern, you could try full-rank ADVI, but I’m not sure I’d recommend it; it hasn’t worked that well for me in the past, and it can be hard to tell when it has converged. May I ask why you’re considering ADVI in the first place – is NUTS too slow? If not, I’d go with NUTS.
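
For reference, full-rank is a one-line change in pymc3 (a sketch, assuming the model context from above):

```python
# Full-rank ADVI fits a full covariance matrix, so it can represent posterior
# correlations, at the cost of many more variational parameters and, in my
# experience, fragile convergence.
with model:
    approx_fr = pm.fit(n=50_000, method="fullrank_advi")
    trace_fr = approx_fr.sample(1000)
```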

Hope this is helpful – let me know!

Best,
Martin

NUTS is unacceptably slow for my use case – the model takes around 30 minutes to sample with NUTS, whereas I need it to run in under a minute for the application in question. It sounds like the JAX integration might help a lot, but it didn’t seem to speed up my model when I tried it (I probably didn’t use it correctly).
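
For reference, this is roughly what I tried (a sketch from memory; jax and numpyro need to be installed, and I gather the model has to use only JAX-compatible ops):

```python
import pymc3.sampling_jax  # experimental module at the time of writing

# `model` stands in for my pymc3 model.
with model:
    trace = pymc3.sampling_jax.sample_numpyro_nuts(draws=1000, tune=1000)
```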

I see… I’m not sure I have a great answer (maybe someone else does?), but here are some suggestions:

  • If you’d be happy to post the model code, we might be able to help make it more efficient and make NUTS viable that way.
  • NUTS with JAX on a GPU may be worth a go.
  • ADVI’s bad (co-)variance estimates may not be the end of the world. If you’re just trying to make good predictions, they may be good enough. One idea would be to fit both NUTS and ADVI on the same data, compare the posteriors and predictions, and then decide whether ADVI is adequate. One tip here: if you try ADVI, bump up the number of iterations from the default; I’ve often found 100,000 to be necessary. You can check this by looking at the loss (see Checking Convergence here, and the sketch after this list).
  • You could have a look at whether more specialised software, like lme4, statsmodels, or INLA, supports your model. These aren’t as flexible as pymc3, but they are optimised for certain classes of models and might be faster for those.
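
To make the ADVI tip above concrete, here’s a sketch of the call with more iterations and a convergence callback (pymc3 API; `model` stands in for your model):

```python
import matplotlib.pyplot as plt
import pymc3 as pm
from pymc3.variational.callbacks import CheckParametersConvergence

# `model` stands in for your pymc3 model.
with model:
    approx = pm.fit(
        n=100_000,  # well past the default of 10,000
        method="advi",
        callbacks=[CheckParametersConvergence(diff="absolute")],
    )

# Eyeball the loss history to confirm it has flattened out.
plt.plot(approx.hist)
plt.ylabel("ELBO loss")
plt.show()
```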

Yes, we’ve seen many models that were “unacceptably slow” with NUTS, but in 99% of cases that slow-down is caused by challenging posterior geometry, and model reparameterizations and better priors made the model sample quickly.
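
One common example (not necessarily what is going on in your model) is rewriting hierarchical terms in non-centered form, which keeps the same distribution but gives NUTS much friendlier geometry:

```python
import pymc3 as pm

with pm.Model():
    mu = pm.Normal("mu", 0.0, 1.0)
    tau = pm.HalfNormal("tau", 1.0)
    # Centered form would be: group_effect ~ Normal(mu, tau), which often
    # produces funnel-shaped geometry that NUTS struggles with.
    z = pm.Normal("z", 0.0, 1.0, shape=10)  # standard-normal offsets
    group_effect = pm.Deterministic("group_effect", mu + tau * z)
```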

I would try Bambi, which has some tricks to avoid common pitfalls like collinearities.
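
Something like the sketch below (recent Bambi API; the formula, file path, and column names are placeholders for your data):

```python
import bambi as bmb
import pandas as pd

# Hypothetical file with outcome y and covariates x1, x2.
df = pd.read_csv("my_data.csv")

# Bambi builds the pymc3 model for you, with default priors scaled to the data.
model = bmb.Model("y ~ x1 + x2", df)
idata = model.fit(draws=1000)
```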
