Introducing bayeux

Dear all – I'm happy to release a library for doing inference in JAX. As of about 20 minutes ago, it works with (most) PyMC models. Check out this quickstart colab to jump right in (requires a recent version of PyMC), read the docs here, or see the code here.

Currently it surfaces a bunch of ways of doing MCMC from blackjax and numpyro, as well as optimization from optax, jaxopt, and optimistix. There is also a VI routine from TFP that doesn't quite work with all bayeux models yet. Coming up next are the VI routines from numpyro, and working out why the ChEES and MEADS samplers from blackjax aren't performing well on most models.

Happy to take pull requests, bug reports etc here or elsewhere.


Feature request: an example of using it with PyMC


Yes! Tuesday morning something similar to the above colab will go up!

Update: it went up Wednesday morning instead! Writing it turned up a bug in bayeux's implementation of structural VI from TFP, but that should now work in general.

Next on the list is trying to add some VI from numpyro.

Probably not (yet) a good idea, but you could delete lots of pymc/sampling/jax.py and replace it with bayeux now 😁 (just don't delete the jaxify functions that bayeux uses…)


If you want to use this with Bambi, just call model.build() and then pass model.backend.model to bx.Model.from_pymc, like this:

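```python
import arviz as az
import bambi as bmb
import bayeux as bx
import jax
import pandas as pd
import pymc as pm

# Simulate 1000 heights from Normal(100, 30)
dist = pm.Normal.dist(mu=100, sigma=30)
draws = pm.draw(dist, draws=1000, random_seed=1000)
df = pd.DataFrame(data=draws, columns=['heights'])
# Intercept-only Bambi model; build() creates the underlying PyMC model
formula = bmb.Formula('heights ~ 1')
model = bmb.Model(formula=formula, family='gaussian', data=df)
model.build()
# Hand the PyMC model to bayeux and sample with numpyro's NUTS
bx_model = bx.Model.from_pymc(model.backend.model)
idata = bx_model.mcmc.numpyro_nuts(seed=jax.random.key(0))
az.summary(idata)
```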

Thank you! Would it be ok to add this to the documentation?

Of course!


Hi Colin! Would this be the best path forward for using JAX + PyMC now?

I also see the PyMC docs page "Faster Sampling with JAX and Numba", but that appears to be from 2022.

I’m a new PyMC user so please forgive me if this is obviously stated somewhere.

For most users, what's in that docs page is enough. bayeux gives you access to more wild/experimental samplers and optimizers, but the ones documented in the PyMC docs suffice for 99% of users.

Nothing wrong with exploring bayeux, though.


Strongly agree with @ricardoV94 – if you're getting started with PyMC, you'll find the best support, the most responsive issue handling, the fastest bug fixes, and a huge number of existing utility functions by just passing nuts_sampler="numpyro" to pm.sample. Once you call bx.Model.from_pymc(model), your work downstream from there will be in the JAX ecosystem.

It's enough to know that you can usually add one line of bayeux and get access to some different samplers (and optimizers and VI).
