Hi, I am experimenting with the new JAX-based sampler that is currently under development in the pymc3jax
channel. I am using this notebook provided by @twiecki: https://gist.github.com/twiecki/f0a28dd06620aa86142931c1f10b5434
I can run the notebook as-is without problems. However, the speed-up I observe is only about 3-fold, instead of the 20-fold speed-up reported by @twiecki. Note that I am running this on macOS (so on CPU, not GPU).
I can also see that a significant amount of time is spent in the last line of the function, where the ArviZ InferenceData instance is created. From conversations on Twitter, I assume that this is where the heavy lifting takes place: in this line (`az_trace = az.from_dict(posterior=posterior)`), the XLA arrays are converted into NumPy arrays, and only here are they actually evaluated (see https://jax.readthedocs.io/en/latest/async_dispatch.html for an explanation).
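To check whether this is really what is going on, here is a small sketch of JAX's asynchronous dispatch as I understand it (this is standalone illustration code, not taken from the notebook): the computation call returns almost immediately, and the actual work only shows up once the result is forced, e.g. by `block_until_ready()` or by converting to NumPy.

```python
import time

import jax.numpy as jnp

# A computation large enough that dispatch and evaluation are distinguishable.
x = jnp.ones((2000, 2000))

t0 = time.perf_counter()
y = jnp.dot(x, x)  # dispatched asynchronously; returns before the work is done
dispatch_time = time.perf_counter() - t0

t0 = time.perf_counter()
y.block_until_ready()  # forces evaluation, like the NumPy conversion in from_dict
compute_time = time.perf_counter() - t0

print(f"dispatch: {dispatch_time:.4f}s, compute: {compute_time:.4f}s")
```

If the dispatch time is tiny and the blocking time is large, that would support the idea that the sampling cost only becomes visible in the `az.from_dict` line.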
I have no previous experience with JAX, so I might be completely wrong here.
Now, I am trying to use the sampler for a custom hierarchical model of mine. Here, sampling breaks down entirely: my Python process gets stuck in the last line of the function, where the arrays are evaluated. I guess this is the same effect as in the example notebook, just taken to the extreme, because the model is more complex (it also samples quite slowly with the normal sampler and shows some divergences).
Thus, my questions:
- Is this slow sampling perhaps caused by macOS? Do I need to change some setting for JAX, XLA, NumPyro, etc.?
- Is there anything else I can try?
- Minor side question: why is `--xla_force_host_platform_device_count` set to 100 in the code? No one has 100 CPUs on their machine, right?
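For context, my understanding (which may be wrong) is that this flag is set via the `XLA_FLAGS` environment variable before importing JAX, and that it creates *virtual* CPU devices rather than assuming 100 physical cores, so that multiple chains can be mapped onto separate devices. A minimal sketch of that setup:

```python
import os

# Must be set before jax is imported; creates virtual host (CPU) devices.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=100"

import jax

# JAX now reports 100 CPU devices, even on a machine with far fewer cores.
print(jax.local_device_count())
```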
(I’m aware that this is very experimental code, so don’t take this as criticism; maybe my observations can help the development.)