Dear all – Happy to release bayeux, a library for doing inference in JAX. It is ready to go as of ~20 minutes ago with (most) PyMC models. Check out this quickstart colab to jump right in (requires a recent version of PyMC), read the docs here, or see the code here.
Currently it surfaces a number of MCMC algorithms from blackjax and numpyro, as well as optimizers from optax, jaxopt, and optimistix. There is also a VI routine from TFP that doesn't quite work with all bayeux models. Coming up: the VI routines from numpyro, and working out why ChEES and MEADS from blackjax aren't performing well on most models.
Happy to take pull requests, bug reports etc here or elsewhere.
Update: It went up Wednesday morning instead! Adding this turned up a bug in bayeux's implementation of structured VI from TFP, but it should now work generally.
Next on the list is trying to add some VI from numpyro.
Probably not (yet) a good idea, but you could delete much of pymc/sampling/jax.py and replace it with bayeux now (just don't delete the jaxify functions that bayeux uses…).
For most users that's enough. bayeux gives you access to more wild/experimental samplers and optimizers, but the ones documented in the docs suffice for 99% of users.
Strong agree with @ricardoV94 – if you're getting started with PyMC, you'll find the best support, the fastest bug fixes, and a huge amount of existing utility functions by just using nuts_sampler="numpyro". Once you call bx.Model.from_pymc(model), everything downstream happens in the JAX ecosystem.
It suffices to be aware that you can usually add one line from bayeux to get access to different samplers (and optimizers and VI routines).