I see… I’m not sure I have a great answer (maybe someone else has?), but here are some suggestions:
- If you’d be happy to post the model code, we might be able to make it more efficient, which could make NUTS more viable.
- NUTS with JAX on a GPU may be worth a go.
- ADVI’s poor (co-)variance estimates may not be the end of the world. Mean-field ADVI tends to underestimate posterior variances and ignores correlations between parameters. But if you’re mainly after good predictions, its results may be good enough. One idea is to fit both NUTS and ADVI to the data, compare the posteriors and predictions, and then decide whether ADVI is good enough. One tip if you try ADVI: increase the number of iterations from the default. I’ve often found 100,000 to be necessary, but you can check this by looking at the loss (see Checking Convergence here).
- You could check whether more specialised software, such as lme4, statsmodels, or INLA, supports your model. These aren’t as flexible as pymc3, but they are optimised for certain classes of models and might be faster for those.
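If you compare ADVI with NUTS, the two checks mentioned above (whether the ADVI loss has flattened out, and how far ADVI's posterior drifts from NUTS) could be sketched roughly like this. It is a minimal sketch that needs only NumPy. `loss_has_plateaued` and `compare_summaries` are hypothetical helper names, and the window size and tolerance are arbitrary choices rather than PyMC3 defaults. In PyMC3 the loss history would come from the `hist` attribute of the approximation returned by `pm.fit`.

```python
import numpy as np


def loss_has_plateaued(hist, window=5000, rel_tol=0.01):
    """Heuristic convergence check for an ADVI loss history.

    Compares the mean loss over the last `window` iterations with the mean
    over the `window` iterations before that. A relative change below
    `rel_tol` is treated as "roughly converged". The thresholds are
    assumptions; tune them for your model.
    """
    hist = np.asarray(hist, dtype=float)
    if hist.size < 2 * window:
        return False  # not enough iterations to judge
    recent = hist[-window:].mean()
    previous = hist[-2 * window:-window].mean()
    return abs(recent - previous) / (abs(previous) + 1e-12) < rel_tol


def compare_summaries(nuts_draws, advi_draws):
    """Compare posterior draws (shape: n_draws x n_params) from NUTS and ADVI.

    Returns, per parameter:
      - the difference in means, in units of the NUTS posterior sd
      - the ratio ADVI sd / NUTS sd (values well below 1 suggest ADVI is
        underestimating the posterior variance)
    """
    nuts = np.asarray(nuts_draws, dtype=float)
    advi = np.asarray(advi_draws, dtype=float)
    nuts_sd = nuts.std(axis=0, ddof=1)
    mean_shift = (advi.mean(axis=0) - nuts.mean(axis=0)) / nuts_sd
    sd_ratio = advi.std(axis=0, ddof=1) / nuts_sd
    return mean_shift, sd_ratio


# Hypothetical usage with PyMC3 (not run here; assumes a model `model`
# with a vector parameter named "beta"):
#   with model:
#       trace = pm.sample(2000)
#       approx = pm.fit(n=100_000, method="advi")
#   loss_has_plateaued(approx.hist)
#   compare_summaries(trace["beta"], approx.sample(4000)["beta"])
```

Comparing summaries in NUTS standard-deviation units keeps the check scale-free, so the same rough thresholds can be applied to every parameter.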