Bambi is faster than my pure pymc implementation; what did I do wrong?

I’m learning pymc and bambi (coming from R/lme4), and as a starter I implemented a simple multi-level model in all three frameworks: lme4, bambi and pymc.

The code runs as it should on Colab, available here, but surprisingly, bambi is faster than my pure pymc implementation. The speed I get with pymc is the same with and without a GPU (I created a CPU-only version and ran it on a Colab instance without a GPU), so it seems I failed to get pymc to run on the GPU on Colab. Any hints on how to get jax to use the GPU on Colab are appreciated.
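For checking which backend JAX is actually targeting on the Colab instance, something like the following plain JAX calls should work (nothing PyMC-specific):

import jax

# Reports "gpu" (or "cuda" on newer versions) when the accelerated jaxlib is
# active; "cpu" means JAX fell back to the CPU-only build.
print(jax.default_backend())

# Lists the devices JAX can see, e.g. one GPU entry for the Tesla T4.
print(jax.devices())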

The notebook on Colab is rather lengthy, so I’ll paste the important stuff here:

The bambi version:

import bambi as bmb

my_model = bmb.Model("y ~ 1 + (1 | x)", data=df, family="bernoulli", dropna=True)
my_fit = my_model.fit(random_seed=1234, tune=2000, draws=2000, target_accept=0.95, method="nuts_numpyro", chains=1, chain_method="sequential")
Modeling the probability that y==1
/usr/local/lib/python3.7/dist-packages/aesara/link/jax/dispatch.py:87: UserWarning: JAX omnistaging couldn't be disabled: Disabling of omnistaging is no longer supported in JAX version 0.2.12 and higher: see https://github.com/google/jax/blob/main/docs/design_notes/omnistaging.md.
  warnings.warn(f"JAX omnistaging couldn't be disabled: {e}")
/usr/local/lib/python3.7/dist-packages/pymc/sampling_jax.py:36: UserWarning: This module is experimental.
  warnings.warn("This module is experimental.")

Compiling...
Compilation time =  0:00:29.181411
Sampling...

sample: 100%|██████████| 4000/4000 [01:28<00:00, 45.18it/s, 31 steps of size 1.80e-01. acc. prob=0.95]

Sampling time =  0:01:31.264422
Transforming variables...
Transformation time =  0:00:00.293751
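As a side note, one way to compare the two implementations is to look at the PyMC model that Bambi builds under the hood. Assuming a Bambi version that exposes build() and graph(), something along these lines should work (the backend attribute may differ between versions):

my_model.build()   # build the underlying PyMC model without sampling
my_model.graph()   # render the model graph (requires graphviz)
# Depending on the Bambi version, the raw PyMC model object should be
# reachable via my_model.backend.model for closer inspection.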

The pymc version:

import numpy as np
import pymc as pm
import pymc.sampling_jax

# Number of levels of the grouping factor x (nunique rather than count,
# since count() returns the number of non-null rows, not the number of groups).
number_of_unique_elements_in_grouping_factor_x = int(df["x"].nunique())

basic_model = pm.Model()
with basic_model:
    beta_0 = pm.Normal("beta_0", mu=0, sigma=2.5)
    sigma = pm.HalfNormal("sigma", sigma=2.5)
    mu_0 = pm.Normal("mu_0", mu=0, sigma=sigma, shape=number_of_unique_elements_in_grouping_factor_x)
    x = pm.ConstantData("x", df["x"].to_numpy(), dims="t")
    theta = pm.invlogit(beta_0 + mu_0[x])
    y = pm.ConstantData("y", df["y"].to_numpy(), dims="t")
    Y_obs = pm.Bernoulli("Y_obs", p=theta, observed=y)

with basic_model:
    # draw 2000 posterior samples
    idata_gpu = pm.sampling_jax.sample_numpyro_nuts(random_seed=1234, tune=2000, draws=2000, target_accept=0.95, chains=1)
Compiling...
Compilation time =  0:00:00.928862
Sampling...

sample: 100%|██████████| 4000/4000 [07:12<00:00,  9.25it/s, 63 steps of size 7.52e-02. acc. prob=0.95]

Sampling time =  0:07:12.886486
Transforming variables...
Transformation time =  0:00:00.041902
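One thing I glossed over above: mu_0[x] only works because x already stores integer codes 0..k-1 for the k levels of the grouping factor. Here is a sketch that makes this explicit and gives named coordinates in the trace (coded_model and x_level are just illustrative names, not anything from the notebook):

import pandas as pd
import pymc as pm

# Map the raw grouping values to consecutive integer codes 0..k-1,
# which is what the mu_0[x] indexing relies on.
x_codes, x_levels = pd.factorize(df["x"])

with pm.Model(coords={"x_level": x_levels}) as coded_model:
    beta_0 = pm.Normal("beta_0", mu=0, sigma=2.5)
    sigma = pm.HalfNormal("sigma", sigma=2.5)
    mu_0 = pm.Normal("mu_0", mu=0, sigma=sigma, dims="x_level")
    theta = pm.math.invlogit(beta_0 + mu_0[x_codes])
    Y_obs = pm.Bernoulli("Y_obs", p=theta, observed=df["y"].to_numpy())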

The Colab session is a GPU one, and nvidia-smi shows this (see the notebook for the code that produces this output).

GPU 0: Tesla T4 (UUID: GPU-eee88075-ec18-78e5-b0d3-0c4448bb1523)

FYI: Colab has Python 3.7.

EDIT: here is the link to the notebook I ran on the CPU-only Colab instance: Google Colab. And here is the output:

with basic_model:
    # draw 2000 posterior samples
    idata_gpu = pm.sampling_jax.sample_numpyro_nuts(random_seed=1234, tune=2000, draws=2000, target_accept=0.95, chains = 1)

Compiling...
Compilation time =  0:00:22.985132
Sampling...

sample: 100%|██████████| 4000/4000 [07:59<00:00,  8.34it/s, 63 steps of size 7.52e-02. acc. prob=0.95]

I ran all cells and it seems to use the GPU for me.

This is the output of the sampling:

Compiling...
Compilation time =  0:00:03.956913
Sampling...

sample: 100%|██████████| 4000/4000 [03:07<00:00, 21.29it/s, 63 steps of size 7.46e-02. acc. prob=0.95]

Sampling time =  0:03:10.977764
Transforming variables...
Transformation time =  0:00:00.083630

Edit: I ran it again to double-check and it used the CPU. Unsure what is going on.

I think I understand what is happening, and it isn’t a GPU/CPU thing. Bambi is clever enough to apply pm.invlogit() only to the unique values of beta_0 + mu_0[x], while my code ignored that there are duplicates in x and therefore in beta_0 + mu_0[x].

In real-world scenarios this is not a problem, because I can aggregate identical rows and use the Binomial distribution instead of the Bernoulli to achieve the same optimisation. This was only a “pedagogical example”, and indeed I learnt something from it :slight_smile:
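For completeness, this is roughly what the aggregated Binomial version would look like (just a sketch, not something I benchmarked; binomial_model and the n/k column names are made up for the example):

import pymc as pm

# One row per level of x: n = number of trials, k = number of successes.
agg = df.groupby("x")["y"].agg(n="count", k="sum").reset_index()

with pm.Model() as binomial_model:
    beta_0 = pm.Normal("beta_0", mu=0, sigma=2.5)
    sigma = pm.HalfNormal("sigma", sigma=2.5)
    mu_0 = pm.Normal("mu_0", mu=0, sigma=sigma, shape=len(agg))
    # invlogit is now evaluated once per group instead of once per row
    theta = pm.math.invlogit(beta_0 + mu_0)
    Y_obs = pm.Binomial("Y_obs", n=agg["n"].to_numpy(), p=theta, observed=agg["k"].to_numpy())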