Pymc3jax: JAX sampler slow on Mac OS

Hi, I am experimenting with the new JAX-based sampler that is currently under development in the pymc3jax channel. I am using this notebook provided by @twiecki: https://gist.github.com/twiecki/f0a28dd06620aa86142931c1f10b5434

I can run the notebook as is without problems. However, the speed-up I observe is only about 3-fold, instead of the 20-fold speed-up reported by @twiecki. Note that I am running this on Mac OS (so on CPU, not GPU).

I can also see that a significant amount of time is spent in the last line of the function, where the ArviZ InferenceData instance is created. From conversations on Twitter, I assume that this is where the heavy lifting takes place: in this line (az_trace = az.from_dict(posterior=posterior)), the XLA arrays are converted into NumPy arrays, and only at that point are they actually evaluated (see https://jax.readthedocs.io/en/latest/async_dispatch.html for an explanation).

I have no previous experience with JAX, so I might be completely wrong here.

Now, I am trying to use the sampler for a custom hierarchical model of mine. Here, the sampling breaks down completely, i.e., my Python process gets stuck in the last line of the function, when the arrays are evaluated. I guess this is the same effect as in the example notebook, just taken to the extreme, because the model is more complex (it also samples quite slowly with the normal sampler and has some divergences).

Thus, my questions:

  • Is this slow sampling perhaps caused by Mac OS? Do I need to change some setting for JAX, XLA, NumPyro, etc.?
  • Anything else that I can try out?
  • Minor side question: Why is --xla_force_host_platform_device_count set to 100 in the code? No one has 100 CPUs on their machine, right?

(I’m aware that this is very experimental code, so don’t take it as criticism; maybe my observations can help the development.)


Thank you for posting your experience here so it will serve as a reference for future users. The DeviceArray that JAX returns is indeed a future, i.e., it is not evaluated immediately. To check whether this is the cause, you can take one of the arrays in the trace and call its block_until_ready() method before converting it to ArviZ’s InferenceData object. If you still get a 3x speedup, then there’s nothing wrong with your machine. Still, a 3x speedup is amazing!
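To make that check concrete, here is a minimal sketch of timing the sampling step separately from the conversion step. The `posterior` dict and the `fake_sampler` function are hypothetical stand-ins for what the notebook produces, not code taken from it; only the block_until_ready() call is the part that matters. Whether the first timing is noticeably shorter than the second depends on the backend and JAX version, so treat the numbers as a diagnostic rather than a benchmark.

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def fake_sampler(x):
    # Hypothetical stand-in for the real sampling computation.
    for _ in range(10):
        x = jnp.tanh(x @ x)
    return x


x = jnp.ones((200, 200))

# Warm-up call so that compilation time is not included in the timings.
fake_sampler(x).block_until_ready()

t0 = time.perf_counter()
# Hypothetical posterior: a dict of JAX arrays, as passed to az.from_dict.
posterior = {"theta": fake_sampler(x)}
t_dispatch = time.perf_counter() - t0  # may return before the work is done

# Force evaluation of every array before handing them to ArviZ.
for value in posterior.values():
    value.block_until_ready()
t_total = time.perf_counter() - t0

print(f"time until call returned: {t_dispatch:.4f} s")
print(f"time until results ready: {t_total:.4f} s")
```

If the time until the results are ready accounts for most of what previously showed up in the az.from_dict line, the conversion itself is cheap and the cost really is the sampling.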

The device_count flag in XLA is usually used to make the pmap function “see” the CPU’s cores as different devices (otherwise it sees the whole CPU as one device). Not sure why it’s set to 100 here.


IIUC, XLA only counts physical devices, so we use xla_force_host_platform_device_count to split the CPU into virtual devices so that pmap can work. Setting it to 100 means that you can sample at most 100 chains. There is some explanation in NumPyro: https://github.com/pyro-ppl/numpyro/blob/b9bd7305de373fb42c735803c935b99ef8edf09b/numpyro/util.py#L57-L79
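As a rough illustration of how the flag is set, here is a minimal sketch that writes it into the XLA_FLAGS environment variable. The helper name `force_host_device_count` is hypothetical and only loosely modeled on the NumPyro utility linked above; the one thing it relies on is that XLA reads XLA_FLAGS when the backend starts, so it has to run before JAX is initialized.

```python
import os
import re

FLAG = "--xla_force_host_platform_device_count"


def force_host_device_count(n):
    """Hypothetical helper: ask XLA to expose the CPU as n virtual devices."""
    flags = os.environ.get("XLA_FLAGS", "")
    # Drop any earlier setting of the same flag, keep all other flags.
    others = re.sub(re.escape(FLAG) + r"=\S+", "", flags).split()
    os.environ["XLA_FLAGS"] = " ".join([f"{FLAG}={n}"] + others)


# Must be called before JAX initializes its backend.
force_host_device_count(4)
print(os.environ["XLA_FLAGS"])

# After this, importing jax and calling jax.local_device_count() on a
# CPU-only machine is expected to report 4 devices, which is the number
# of chains pmap could then run in parallel.
```

A value like 4 (one virtual device per chain) would be enough for a typical run; 100 is simply a generous upper bound on the number of chains.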