Speed Up Bambi Fitting

Hi Bambi Folks,

Thanks for making Bayesian inference more accessible! I’m facing an issue where the compute time grows exponentially with data size. I’m running on a Colab Pro plan — does Bambi support GPU acceleration? Any ideas on how to speed it up?

I’m currently working with a sample of 250K rows, and the model has been running for more than 6 hours.

Yes, like all PyMC models, Bambi supports GPU via JAX-based samplers: nutpie (with its JAX backend), numpyro, bayeux, etc. For example:


import bambi as bmb

# Small built-in example dataset; swap in your own DataFrame
data = bmb.load_data("sleepstudy")
model = bmb.Model('Reaction ~ Days', data)

# Option 1: Nutpie
results = model.fit(draws=1000,
                    inference_method='nutpie',
                    nuts_sampler_kwargs={'backend':'jax'})

# Option 2: Numpyro
results = model.fit(draws=1000,
                    inference_method='numpyro',
                    # Options are 'vectorized', 'parallel', or 'sequential'
                    nuts_sampler_kwargs={'chain_method':'vectorized'})

There are many more JAX backends available via bayeux/blackjax; check out the docs for details. (I don’t see anything bayeux-specific there, so maybe @tcapretto can point you to a better place to look.)
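Before launching a long fit, it may be worth confirming that JAX actually sees the GPU, since JAX falls back to CPU when no accelerator is detected. A minimal sketch, assuming `jax` is installed in the Colab runtime (the variable names here are just for illustration):

```python
import jax

# List the devices JAX can use; on a GPU runtime at least one
# device should report the "gpu" platform.
devices = jax.devices()
platforms = {d.platform for d in devices}
has_gpu = "gpu" in platforms

print("JAX devices:", devices)
print("GPU available to JAX:", has_gpu)
```

If this prints `False`, the JAX-based samplers will likely still run, just on CPU, so you would not see a GPU speedup.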

Thank you very much — this is extremely helpful. I have one more question, if I may: is it possible to run simulations in Bambi? Does it support answering “what if” questions?

@jessegrabowski

You can always extract the PyMC model inside the Bambi model with pm_model = bambi_model.backend.model. Once you have that, you can do anything you want with it. PyMC offers pm.do to do graph interventions. See here and here for information about pm.do, and applied examples.

You could also explore the tools available in the bambi.interpret module. One entry point is the "Plot Conditional Adjusted Predictions" page in the Bambi docs.