I may have some (slightly half-baked) things to add here.
I’ve noticed that the progress bar can be misleading when using `numpyro`, and I think it might have something to do with the placement of `jax.jit` here. What I suspect is happening is this: because of the decorator, JAX compiles the entire sampling function, and compilation begins with a tracing step. Tracing is very fast, so the progress bar appears to advance very quickly. The actual computation only happens afterwards, which makes the progress bar quite misleading.
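A small illustration of the general JAX behaviour behind this suspicion. This is not numpyro’s actual progress-bar code; the function `step` and the `trace_count` counter are hypothetical. Python-side effects inside a `jax.jit`-decorated function run only while JAX traces it, not every time the compiled computation executes. Anything driven from Python inside the jitted function therefore only reports on tracing.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # hypothetical Python-side counter, e.g. standing in for a progress update


@jax.jit
def step(x):
    global trace_count
    trace_count += 1  # Python side effect: runs while JAX traces, not on every call
    return jnp.sin(x) + x


for _ in range(5):
    y = step(jnp.float32(1.0))
y.block_until_ready()  # dispatch is asynchronous; wait for the device result

print(trace_count)  # 1
```

The function runs five times, but the counter only moves once: on the first call, when JAX traces it. Later calls reuse the compiled version without re-executing the Python body.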
One way to avoid this would be to `jit`-compile only the `logp` function, e.g. at this point. I don’t know whether that would be slower (perhaps there is a benefit to compiling the entire `_sample` function), but I think it would be worth testing.
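To make the two placements concrete, here is a sketch using a toy random-walk Metropolis sampler rather than numpyro’s NUTS. The standard-normal target and the names `logp`, `rw_metropolis_python_loop` and `rw_metropolis_scan` are hypothetical, chosen only for illustration. Variant A jits only the log-density and keeps the loop in Python, so anything updated in the loop tracks completed iterations. Variant B jits the whole sampler as one `jax.lax.scan`.

```python
from functools import partial

import jax
import jax.numpy as jnp


def logp(theta):
    # Hypothetical target: standard-normal log-density, up to an additive constant
    return -0.5 * jnp.sum(theta ** 2)


# Variant A: jit only the log-density; the sampling loop runs in Python,
# so anything updated inside the loop (e.g. a progress bar) tracks real work.
logp_jit = jax.jit(logp)


def rw_metropolis_python_loop(key, theta0, n_steps, step_size=0.5):
    theta, lp = theta0, logp_jit(theta0)
    samples = []
    for _ in range(n_steps):
        key, k_prop, k_acc = jax.random.split(key, 3)
        proposal = theta + step_size * jax.random.normal(k_prop, theta.shape)
        lp_prop = logp_jit(proposal)
        # Converting the comparison to a Python bool forces a device sync each step
        if jnp.log(jax.random.uniform(k_acc)) < lp_prop - lp:
            theta, lp = proposal, lp_prop
        samples.append(theta)
        # A progress-bar update placed here would reflect completed iterations
    return jnp.stack(samples)


# Variant B: jit the whole sampler; the loop becomes a single compiled lax.scan.
@partial(jax.jit, static_argnums=2)
def rw_metropolis_scan(key, theta0, n_steps, step_size=0.5):
    def one_step(carry, k):
        theta, lp = carry
        k_prop, k_acc = jax.random.split(k)
        proposal = theta + step_size * jax.random.normal(k_prop, theta.shape)
        lp_prop = logp(proposal)
        accept = jnp.log(jax.random.uniform(k_acc)) < lp_prop - lp
        theta = jnp.where(accept, proposal, theta)
        lp = jnp.where(accept, lp_prop, lp)
        return (theta, lp), theta

    keys = jax.random.split(key, n_steps)
    _, samples = jax.lax.scan(one_step, (theta0, logp(theta0)), keys)
    return samples


key = jax.random.PRNGKey(0)
theta0 = jnp.zeros(2)
samples_a = rw_metropolis_python_loop(key, theta0, 200)
samples_b = rw_metropolis_scan(key, theta0, 200)
print(samples_a.shape, samples_b.shape)  # (200, 2) (200, 2)
```

Variant A pays a host/device round trip on every iteration, so it is likely to be slower. How large the gap is would have to be measured on the real model and hardware, for example with `time.perf_counter()` around each call followed by `.block_until_ready()`.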
I have two more comments about running on GPU. First, and you probably know this already, numpyro requires you to request the GPU explicitly, so the start of the script has to include:
```python
import numpyro
numpyro.set_platform("gpu")
```
Secondly, I’ve found that when running on a single GPU, numpyro’s `vectorized` chain method is much more efficient than the `parallel` one (used here). As I understand it, when only a single GPU is available, `chain_method='parallel'` runs the chains sequentially, one after another, whereas `vectorized` runs them all at once. `vectorized` is marked as experimental, but I asked and was told it should be reliable.
I hope this is useful. Devs, let me know if I should turn any of these into issues or pull requests. I think it would be good to add `vectorized` as an option and, if you agree, to benchmark where to place the `jit` annotation.