I used the pymc3jax branch today and tried sampling a model to see if I got any speedup. I had to turn off Theano shared variables (which I was using for all data) to get it to work (otherwise I got the cryptic Theano error MissingInput: Input 0 of the graph...), but otherwise it worked fine. However, I obtained only a very small speedup:
pymc3 default: 2 hours and 44 minutes
pymc3 JAX: 2 hours and 26 minutes
And since only the former actually samples using Python code, only the former gave me a progress bar (which to me is worth a 10% performance penalty)! I had been expecting speedups on the order of 7.5x, as demonstrated in the example posted by the developers.
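For reference, here is the arithmetic behind those two timings as a quick sketch. It uses only the numbers quoted above; the variable names are just for illustration.

```python
# Wall-clock times quoted above, converted to minutes.
default_min = 2 * 60 + 44  # pymc3 default sampler
jax_min = 2 * 60 + 26      # pymc3 JAX sampler

speedup = default_min / jax_min         # how many times faster the JAX run was
time_saved = 1 - jax_min / default_min  # fraction of the default runtime saved
expected_min = default_min / 7.5        # runtime if the 7.5x from the demo had held

print(f"speedup: {speedup:.2f}x")                     # speedup: 1.12x
print(f"time saved: {time_saved:.1%}")                # time saved: 11.0%
print(f"runtime at 7.5x: {expected_min:.1f} min")     # runtime at 7.5x: 21.9 min
```

So the observed gain is about 1.12x, against a run of roughly 22 minutes if the demo's speedup had carried over.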
Does anyone know for what reasons my speedup might have been so small? Is there something about model specification (or mis-specification) that might give rise to it?
We would be interested in taking a look at your setup. For example, which sampler are you using (the numpyro one can give you a progress bar, with some additional penalty to the run time)? Are you sampling on CPU or GPU? Do you have a lot of for loops in your model?
I used the numpyro sampler with the same arguments used in the demo notebook. My code looks just like cells 5 and 6 in that notebook (except I did not add the compute_convergence_checks=False kwarg). I used the CPU (probably dumb, and maybe this is the first thing I should switch; is there a demo for using GPU with these new samplers?). I have no for loops or flow control in the model. I’d be happy to share the notebook and data for reproducing it over email (rgerkin at asu dot edu).
I may have some (slightly half-baked) things to add here.
I’ve noticed that the progress bar can be misleading when using numpyro, and I think it might have something to do with the placement of jax.jit here. What I suspect is happening is that the decorator makes JAX compile the entire sampling function, which includes an initial tracing step. That initial step is very fast, so the progress bar appears to advance very quickly. JAX only does the real computation afterwards, so the bar does not reflect actual sampling progress.
One way to avoid this would be to jit compile only the logp function, e.g. at this point. I don’t know if that would be slower. Perhaps there is a benefit to compiling the entire _sample function, but I think this would be good to test.
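To illustrate the tracing behaviour I mean, here is a minimal sketch in plain JAX. It is not the pymc3 sampling code: toy_logp is a hypothetical stand-in for a model's logp, and the list append plays the role of any Python-side effect such as a progress bar update.

```python
import jax
import jax.numpy as jnp

python_calls = []  # records each time the Python body of toy_logp actually runs


def toy_logp(x):
    # Hypothetical stand-in for a model's log-probability function.
    # This append is a Python side effect, so it is not part of the compiled graph.
    python_calls.append("traced")
    return -0.5 * jnp.sum(x ** 2)


jitted_logp = jax.jit(toy_logp)
x = jnp.ones(3)

first = jitted_logp(x)   # traces and compiles: the Python body runs once
second = jitted_logp(x)  # reuses the compiled function: the Python body is skipped

print(len(python_calls))
print(float(first), float(second))
```

The Python body runs only while JAX traces the function, not on later calls with the same input shape and dtype. If something similar applies to a progress bar updated inside a jitted sampling function, the bar would advance during tracing rather than during the real computation. Whether that is what happens in the pymc3 code is still my guess.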
I have two more comments about running on GPU. First, as you probably know already, numpyro has to be told explicitly to run on the GPU. This means that the start of the script needs:
import numpyro
numpyro.set_platform("gpu")
Secondly, I’ve found that when running on a single GPU, numpyro’s vectorized chain method is actually much more efficient than the parallel one (used here). As I understand it, if only a single GPU is available, chain_method='parallel' runs the chains in sequence, one after another, while vectorized runs them all at once. vectorized is marked as experimental, but I asked and it should be reliable.
I hope this is useful. Devs, let me know if I should turn any of these things into issues / pull requests. I think it would be good to add vectorized as an option and, if you agree, to do some benchmarks on where to place the jit annotation.