Introducing bayeux

Dear all – Happy to release a library for doing inference in JAX. It is ready to go as of ~20 minutes ago with (most) PyMC models. Check out this quickstart colab to jump right in (requires a recent version of PyMC), read the docs here or see the code here.

Currently it surfaces a bunch of ways of doing MCMC from blackjax and numpyro, as well as optimization from optax, jaxopt, and optimistix. There is also a VI routine from TFP that doesn’t quite work with all bayeux models yet. Coming up next: the VI routines from numpyro, and working out why chees and meads from blackjax aren’t working very well on most models.

Happy to take pull requests, bug reports etc here or elsewhere.


Feature request: an example of using this with PyMC


Yes! Tuesday morning something similar to the above colab will go up!

Update: It went up Wednesday morning instead! Adding it turned up a bug in bayeux’s implementation of structural VI from TFP, but that should now work generally.

Next on the list is trying to add some VI from numpyro.

Probably not (yet) a good idea, but you could now delete a lot of pymc/sampling/jax.py and replace it with bayeux :grin: (just don’t delete the jaxify functions that bayeux uses…)


If you want to use this with Bambi, call model.build() and then pass model.backend.model to bx.Model.from_pymc, like this:

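```python
import arviz as az
import bambi as bmb
import bayeux as bx
import jax
import pandas as pd
import pymc as pm

# Simulate 1000 heights from Normal(100, 30)
dist = pm.Normal.dist(mu=100, sigma=30)
draws = pm.draw(dist, draws=1000, random_seed=1000)
df = pd.DataFrame(data=draws, columns=['heights'])

# Intercept-only Bambi model
formula = bmb.Formula('heights ~ 1')
model = bmb.Model(formula=formula, family='gaussian', data=df)

# Build the underlying PyMC model, then hand it to bayeux
model.build()
bx_model = bx.Model.from_pymc(model.backend.model)

idata = bx_model.mcmc.numpyro_nuts(seed=jax.random.key(0))
az.summary(idata)
```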

Thank you! Would it be ok to add this to the documentation?

Of course!
