I may have some (slightly half-baked) things to add here.
I’ve noticed that the progress bar can be misleading when using numpyro, and I think it might have something to do with the placement of jax.jit here. What I suspect is happening is that the decorator makes JAX compile the entire sampling function, which starts with a tracing step. That tracing step is very fast, so the progress bar appears to race ahead. JAX only does the real computation afterwards, so the bar says little about how far the sampling has actually got.
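To illustrate the tracing behaviour I have in mind, here is a toy sketch in plain JAX. It is not the PyMC code; step and trace_count are hypothetical names made up for the example. It only shows that Python-level side effects inside a jitted function run while JAX traces it, not on each later call:

```python
import jax
import jax.numpy as jnp

trace_count = 0  # how many times the Python body of `step` has executed


@jax.jit
def step(x):
    global trace_count
    trace_count += 1  # Python side effect: runs during tracing only
    return x + 1.0


x = jnp.zeros(())
for _ in range(5):
    x = step(x)  # later calls reuse the compiled function

print(trace_count, float(x))
```

The body runs once even though step is called five times, which is the kind of mismatch that could make a Python-side progress indicator run ahead of the real work.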
One way to avoid this would be to jit-compile only the logp function, e.g. at this point. I don’t know whether that would be slower (perhaps there is a benefit to compiling the entire _sample function), but I think it would be good to test.
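A rough sketch of what jitting only the log-density could look like. To keep it short it uses a plain random-walk Metropolis loop rather than NUTS, and logp here is a hypothetical stand-in (a standard normal), not the model's real logp. The point is only that the outer loop stays in Python, so a counter or progress bar updated there tracks real work:

```python
import jax
import jax.numpy as jnp


@jax.jit
def logp(x):
    # Hypothetical stand-in for the model log-density: standard normal
    return -0.5 * jnp.sum(x ** 2)


def sample(key, num_steps=200, step_size=0.5):
    x = jnp.zeros(2)
    current = logp(x)
    draws = []
    completed = 0
    for _ in range(num_steps):  # plain Python loop, not traced by JAX
        key, k_prop, k_acc = jax.random.split(key, 3)
        proposal = x + step_size * jax.random.normal(k_prop, x.shape)
        proposed = logp(proposal)
        # Metropolis accept/reject on concrete values
        if jnp.log(jax.random.uniform(k_acc)) < proposed - current:
            x, current = proposal, proposed
        draws.append(x)
        completed += 1  # a progress bar would be updated here
    return jnp.stack(draws), completed


draws, completed = sample(jax.random.PRNGKey(0))
```

Every iteration pays Python dispatch overhead, so this may well be slower than compiling the whole loop; that trade-off is what would need benchmarking.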
I have two more comments on running on GPU. First, and you probably know this already, numpyro has to be told explicitly that you want to run on GPU. This means that at the start of the script there has to be:
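A minimal sketch, assuming the call in question is NumPyro's numpyro.set_platform (the exact snippet is not shown above):

```python
import numpyro

# Assumption: numpyro.set_platform is the call meant here. It has to run
# before any JAX computation, so it goes at the very top of the script.
numpyro.set_platform("gpu")
```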
Secondly, I’ve found that when running on a single GPU, numpyro’s vectorized chain method is actually much more efficient than the parallel one (used here). As I understand it, if only a single GPU is available, chain_method='parallel' runs the chains in sequence, one after another, while vectorized runs them all at once. vectorized is marked as experimental, but I asked and it should be reliable.
I hope this is useful. Devs, let me know if I should turn any of these things into issues / pull requests. I think it would be good to add
vectorized as an option and, if you agree, to do some benchmarks on where to place the