Speed Up Bambi Fitting

Hi Bambi Folks,

Thanks for making Bayesian inference more accessible! I'm facing an issue where the compute time grows very quickly (seemingly exponentially) with data size. I'm running on a Colab Pro plan. Does Bambi support GPU acceleration? Any ideas on how to speed it up?

I’m currently working with a sample of 250K rows, and the model has been running for more than 6 hours.

Yes, like all PyMC models, Bambi models can run on a GPU via JAX-based samplers such as nutpie (with its JAX backend), numpyro, and those available through bayeux. For example:


import bambi as bmb

# Small example dataset that ships with Bambi
data = bmb.load_data("sleepstudy")
model = bmb.Model('Reaction ~ Days', data)

# Option 1: Nutpie, using its JAX backend
results = model.fit(draws=1000,
                    inference_method='nutpie',
                    nuts_sampler_kwargs={'backend':'jax'})

# Option 2: Numpyro
results = model.fit(draws=1000,
                    inference_method='numpyro',
                    # Options are 'vectorized', 'parallel', or 'sequential'
                    nuts_sampler_kwargs={'chain_method':'vectorized'})

There are many more JAX-backed samplers available via bayeux/blackjax; check out the docs for details on all that. (I don't see a page covering bayeux specifically, so maybe @tcapretto can point you to a better place to look.)
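Before switching samplers, it can help to confirm that JAX actually sees the Colab GPU, since JAX falls back to the CPU when it finds no accelerator. A minimal check, assuming `jax` is installed in the runtime (the variable names are only for illustration):

```python
import jax

# Devices JAX can use; on a GPU runtime this list should include a GPU device.
devices = jax.devices()
# Name of the platform JAX will use by default (for example "cpu" or "gpu").
backend = jax.default_backend()

print("devices:", devices)
print("default backend:", backend)
```

If only CPU devices show up, the samplers above will still run, just without GPU acceleration.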

Thank you very much — this is extremely helpful. I have one more question, if I may: is it possible to run simulations in Bambi? Does it support answering “what if” questions?

@jessegrabowski

You can always extract the PyMC model inside the Bambi model with pm_model = bambi_model.backend.model. Once you have that, you can do anything you want with it. PyMC offers pm.do to do graph interventions. See here and here for information about pm.do, and applied examples.

You could also explore the tools available in the bambi.interpret module. One entry point is the "Plot Conditional Adjusted Predictions" page in the Bambi docs.