I’m learning PyMC and Bambi (coming from R/lme4), and as a starter I implemented a simple multilevel model in all three frameworks: lme4, Bambi, and PyMC.

The code runs as it should on Colab, available here, but surprisingly, Bambi is faster than my pure PyMC implementation. The speed I get with PyMC is the same with and without GPU (I created a CPU-only version and ran it on a Colab instance without a GPU), so it seems I failed to get PyMC to run on the GPU on Colab. Any hints on how to get JAX to use the GPU on Colab are appreciated.
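For reference, a quick way to check which backend JAX actually sees (a minimal sketch, assuming `jax` is importable in the runtime):

```python
import jax

# "gpu" if JAX is GPU-backed, otherwise "cpu" (or "tpu")
print(jax.default_backend())

# Lists the devices JAX can see; on a GPU runtime this should include a GPU device
print(jax.devices())
```

If this prints `cpu` on a GPU Colab instance, the installed `jaxlib` wheel is likely CPU-only, which would explain identical timings with and without the GPU.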

The notebook on Colab is rather lengthy, so I’ll paste the important stuff here:

The Bambi version:

```
my_model = bmb.Model("y ~ 1 + (1 | x)", data=df, family="bernoulli", dropna=True)
my_fit = my_model.fit(random_seed=1234, tune=2000, draws=2000, target_accept=0.95, method="nuts_numpyro", chains=1, chain_method="sequential")
Modeling the probability that y==1
/usr/local/lib/python3.7/dist-packages/aesara/link/jax/dispatch.py:87: UserWarning: JAX omnistaging couldn't be disabled: Disabling of omnistaging is no longer supported in JAX version 0.2.12 and higher: see https://github.com/google/jax/blob/main/docs/design_notes/omnistaging.md.
warnings.warn(f"JAX omnistaging couldn't be disabled: {e}")
/usr/local/lib/python3.7/dist-packages/pymc/sampling_jax.py:36: UserWarning: This module is experimental.
warnings.warn("This module is experimental.")
Compiling...
Compilation time = 0:00:29.181411
Sampling...
sample: 100%|██████████| 4000/4000 [01:28<00:00, 45.18it/s, 31 steps of size 1.80e-01. acc. prob=0.95]
Sampling time = 0:01:31.264422
Transforming variables...
Transformation time = 0:00:00.293751
```

The PyMC version:

```
import numpy as np
import pymc as pm
import pymc.sampling_jax

# number of distinct levels of the grouping factor (nunique, not count,
# which would return the number of non-null rows)
number_of_unique_elements_in_grouping_factor_x = int(df["x"].nunique())

basic_model = pm.Model()
with basic_model:
    beta_0 = pm.Normal("beta_0", mu=0, sigma=2.5)
    mu_0 = pm.Normal("mu_0", mu=0, sigma=pm.HalfNormal("sigma", sigma=2.5), shape=number_of_unique_elements_in_grouping_factor_x)
    x = pm.ConstantData("x", df["x"].to_numpy(), dims="t")
    theta = pm.invlogit(beta_0 + mu_0[x])
    y = pm.ConstantData("y", df["y"].to_numpy(), dims="t")
    Y_obs = pm.Bernoulli("Y_obs", p=theta, observed=df["y"])

with basic_model:
    # draw 2000 posterior samples
    idata_gpu = pm.sampling_jax.sample_numpyro_nuts(random_seed=1234, tune=2000, draws=2000, target_accept=0.95, chains=1)
Compiling...
Compilation time = 0:00:00.928862
Sampling...
sample: 100%|██████████| 4000/4000 [07:12<00:00, 9.25it/s, 63 steps of size 7.52e-02. acc. prob=0.95]
Sampling time = 0:07:12.886486
Transforming variables...
Transformation time = 0:00:00.041902
```
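One thing worth noting about the model above: the `mu_0[x]` indexing only works if `x` holds integer codes in the range `0..k-1`. If the grouping factor is not already coded that way, `pd.factorize` produces suitable codes (a sketch with hypothetical data, not the actual `df` from the notebook):

```python
import pandas as pd

# Hypothetical grouping factor with string levels
df = pd.DataFrame({"x": ["a", "b", "a", "c"], "y": [0, 1, 1, 0]})

# factorize returns integer codes 0..k-1 (in order of first appearance)
# plus the array of distinct levels
codes, levels = pd.factorize(df["x"])
n_levels = len(levels)

print(codes)     # [0 1 0 2]
print(n_levels)  # 3
```

The `codes` array can then be passed to `pm.ConstantData` and used to index `mu_0`, and `n_levels` supplies the `shape` of the varying intercept.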

The Colab session is a GPU one, and `nvidia-smi` shows this (see the notebook for the code that produces this output):

```
GPU 0: Tesla T4 (UUID: GPU-eee88075-ec18-78e5-b0d3-0c4448bb1523)
```

FYI: Colab runs Python 3.7.

EDIT: here is the link to the notebook I ran on the CPU-only Colab instance: Google Colab. And here is the output:

```
with basic_model:
    # draw 2000 posterior samples
    idata_gpu = pm.sampling_jax.sample_numpyro_nuts(random_seed=1234, tune=2000, draws=2000, target_accept=0.95, chains=1)
Compiling...
Compilation time = 0:00:22.985132
Sampling...
sample: 100%|██████████| 4000/4000 [07:59<00:00, 8.34it/s, 63 steps of size 7.52e-02. acc. prob=0.95]
```