I am currently implementing a hierarchical linear model (with correlated slopes and intercepts) that includes various random effects, using a single predictor. The dataset is huge: nearly 300k observations. When I try to sample (4 chains, 1000 draws, 1000 tuning iterations), the sampler takes a very long time, more than the 14 hours of runtime I have available. I have tried reducing the dataset to ~140k observations, and sampling still takes about 12 hours.
I am using a non-centred parametrization of the multivariate normal to model the covariance between slopes and intercepts. When I do manage to get samples with the smaller dataset, I get some divergences (~40 per chain), but judging from pair plots they don't seem to be caused by funnels.
I am wondering whether the long sampling time is caused by a problem in my model (the linear relationship is very weak for some groups) or simply by the size of the dataset.
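For reference, the non-centred parametrization mentioned above can be sketched in plain NumPy: draw standard-normal offsets and scale them by the Cholesky factor of the intercept/slope covariance, so the group effects get the intended covariance without sampling them directly from a correlated normal. This is a minimal generative sketch, not the poster's model; the function name `noncentred_group_effects` and the example values are hypothetical.

```python
import numpy as np


def noncentred_group_effects(mu, sds, corr, n_groups, rng):
    """Generate correlated (intercept, slope) pairs via the non-centred form.

    mu:   population means, shape (2,)
    sds:  group-level standard deviations, shape (2,)
    corr: 2x2 correlation matrix
    """
    # Covariance from standard deviations and correlations.
    cov = np.outer(sds, sds) * corr
    # Lower-triangular Cholesky factor L with L @ L.T == cov.
    chol = np.linalg.cholesky(cov)
    # Independent standard-normal offsets, one row per group.
    z = rng.standard_normal((n_groups, len(mu)))
    # Each row becomes mu + L @ z_i, whose covariance is L L^T = cov.
    return mu + z @ chol.T
```

In a sampler, `z` is the parameter actually explored, which is why this form tends to avoid funnel geometry when the group-level standard deviations are small.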
It sounds like the dataset size is the simplest explanation. You might try sampling with sample_numpyro_nuts, particularly if you can get your hands on a GPU (e.g. on Colab); I get significant speed-ups for my larger models. Another option is ADVI: your mileage may vary in terms of the quality of the approximation, but variational inference will run far faster.
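One rough way to see why data size matters so much: every NUTS leapfrog step needs the gradient of the full-data log density, so the work per step grows linearly with the number of observations. The sketch below computes an upper bound on per-observation gradient terms for a run, assuming PyMC's default maximum tree depth of 10 (at most 2**10 - 1 = 1023 leapfrog steps per iteration); the function name `nuts_gradient_work` is hypothetical.

```python
def nuts_gradient_work(n_obs, chains=4, draws=1000, tune=1000, max_treedepth=10):
    """Upper bound on per-observation gradient terms evaluated by NUTS.

    Each NUTS iteration takes at most 2**max_treedepth - 1 leapfrog steps,
    and each leapfrog step evaluates the gradient over all n_obs observations.
    """
    max_leapfrog_steps = 2 ** max_treedepth - 1
    iterations = chains * (draws + tune)
    return iterations * max_leapfrog_steps * n_obs


# Worst case for the full dataset described in the question.
full = nuts_gradient_work(300_000)
half = nuts_gradient_work(140_000)
```

The bound is linear in `n_obs`, but the number of leapfrog steps actually taken depends on the posterior geometry, so dataset size is only one of the factors behind the runtime.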
Hi @fonnesbeck, thanks for the reply!
I tried to use sample_numpyro_nuts after installing JAX and NumPyro. However, I get the following exception:
Exception: Disabling of omnistaging is no longer supported in JAX version 0.2.12 and higher.
When I use an earlier version of JAX, I get this instead:
AttributeError: module 'jaxlib.xla_client' has no attribute 'get_local_backend'
Apparently other people have the same problem (here).
May I ask how you got NumPyro NUTS to work? Thank you, and sorry if this drifts slightly from the original question.
@fagiolino We no longer support sample_numpyro_nuts in PyMC3 v3.x; it is only available in the upcoming PyMC v4 major release. You can install the beta version to try it out.
Yes, I strongly recommend running the v4 beta:
pip install -U pymc --pre
I’ve had no issues with sample_numpyro_nuts in v4, so hopefully this will work for you.