I have noticed that when sampling with nutpie, the sampler finishes a single chain very quickly, while the remaining chains take considerably longer. Is this expected behavior?
I can think of two reasons for this:
The posterior geometry is very problematic, and the chains get stuck in different regions of the posterior, where the number of gradient evaluations per draw and the step size are very different.
Or this is a threading issue, where for some reason one chain gets much more CPU time. Things like this can sometimes happen with the JAX backend, and I don’t really know what to do about that. I’ve never seen this with numba.
Given that you have very small stepsizes, lots of divergences and very different numbers of gradient evaluations per draw, my guess would be the first one. If so, you’ll have to work on your model a bit, and either change or reparametrize it.
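To check which of the two explanations applies, it can help to compare the sampler statistics chain by chain. The following is only a minimal sketch: `per_chain_summary` is a hypothetical helper, not part of nutpie, and the numbers are made up for illustration. It assumes you already have the step size, the number of gradient steps, and the divergence flags as arrays of shape (chains, draws).

```python
import numpy as np


def per_chain_summary(step_size, n_steps, diverging):
    """Summarize sampler statistics per chain.

    Each argument is an array of shape (n_chains, n_draws).
    """
    step_size = np.asarray(step_size, dtype=float)
    n_steps = np.asarray(n_steps, dtype=float)
    diverging = np.asarray(diverging, dtype=bool)
    return {
        # Average over draws (axis 1) so there is one value per chain
        "mean_step_size": step_size.mean(axis=1),
        "mean_n_steps": n_steps.mean(axis=1),
        "n_divergences": diverging.sum(axis=1),
    }


# Made-up numbers for illustration only: chain 0 looks healthy, chain 1 is stuck
step_size = [[0.5, 0.5], [0.001, 0.001]]
n_steps = [[7, 7], [1023, 1023]]
diverging = [[False, False], [True, False]]

summary = per_chain_summary(step_size, n_steps, diverging)
for name, values in summary.items():
    print(name, values)
```

With an ArviZ `InferenceData` object, the inputs would probably come from something like `trace.sample_stats["step_size"].values`; the exact variable names are an assumption and may differ between samplers. If the slow chains also show tiny step sizes and many more gradient evaluations per draw, geometry is the more likely culprit. If the statistics look similar across chains, threading becomes more plausible.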
re: threading in JAX, setting BLAS flags as described in this blog post helps me immensely.
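For anyone who wants to try this, here is a rough sketch of one common way to cap BLAS/OpenMP threads from Python. It is not necessarily what the blog post recommends. `limit_blas_threads` is a hypothetical helper name. The environment variables themselves are the standard ones read by OpenMP, OpenBLAS and MKL, and they must be set before numpy, jax, or the sampler are imported.

```python
import os


def limit_blas_threads(n_threads=1):
    """Cap the threads used by common BLAS/OpenMP backends.

    Call this before importing numpy, jax, or any library that loads BLAS,
    because the variables are read when those libraries are loaded.
    """
    value = str(n_threads)
    for name in ("OMP_NUM_THREADS", "OPENBLAS_NUM_THREADS", "MKL_NUM_THREADS"):
        os.environ[name] = value
    return value


# One BLAS thread per process, so that parallel chains do not oversubscribe cores
limit_blas_threads(1)

# Import numerical libraries only after the variables are set, e.g.:
# import numpy as np
```

Which of these variables actually has an effect depends on the BLAS library your installation links against, so treat this as something to experiment with rather than a guaranteed fix.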
But I agree that if the stepsize is 0.0, it’s probably a geometry problem.
Now I’m confused
I thought JAX didn’t use BLAS but relied on Eigen?
I appreciate the response. I think this is probably a threading issue, as I am able to recreate it with a simpler model. I will look into that BLAS flags solution. Thanks again for your help!
This page suggests that Eigen can be configured to call out to BLAS/LAPACK in certain cases. I can 100% report that setting the threads matters, so maybe this is happening?