Using GPUs correctly in PyMC Marketing

Hello,

I am trying to run an MMM using PyMC Marketing. Because the model takes a long time to run on CPU, I am using 2 NVIDIA T4 GPUs.

In the fit function I am using either numpyro or blackjax as the sampler, but I don't see any speed improvement (I have verified that the model is in fact using the GPUs).
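For reference, one quick way to confirm which devices JAX (the backend behind the numpyro and blackjax samplers) can see is to list them. This is a minimal sketch; `jax_devices_summary` is a hypothetical helper name, and it simply reports `"jax not installed"` when JAX is unavailable.

```python
import importlib.util


def jax_devices_summary():
    """Return a list like ["gpu:0", "gpu:1"] describing the devices JAX sees."""
    # Avoid an ImportError if JAX is not installed in this environment.
    if importlib.util.find_spec("jax") is None:
        return "jax not installed"
    import jax

    # Each Device object exposes its platform ("cpu", "gpu", ...) and an integer id.
    return [f"{d.platform}:{d.id}" for d in jax.devices()]


print(jax_devices_summary())
```

With two T4s visible, one would expect two `gpu:` entries; seeing only `cpu:0` would mean JAX is falling back to the CPU.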

Is there something I am missing?

Here is how I am configuring the fit:

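import numpy as np

# Reproducible seed derived from the string "mmm"
seed: int = sum(map(ord, "mmm"))
rng: np.random.Generator = np.random.default_rng(seed=seed)

# Keyword arguments passed through mmm.fit to the sampler
sampler_kwargs = {
    "draws": 5_000,
    "chains": 2,
    "random_seed": rng,
    "tune": 1000,
    "progressbar": True,
}

# mmm, X and y are defined earlier in the notebook
mmm.fit(X=X, y=y, nuts_sampler="blackjax", **sampler_kwargs)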
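To make "no speed improvement" concrete, it can help to time a short run with each sampler under identical settings. The sketch below is a hypothetical helper (`benchmark_samplers` is not part of PyMC Marketing) that takes any fit callable, so it can be pointed at `mmm.fit`. The commented `nuts_sampler_kwargs={"chain_method": "vectorized"}` line is an assumption about how your PyMC version forwards options to the JAX samplers; check it against your installed version before relying on it.

```python
import time
from typing import Any, Callable, Dict, Iterable


def benchmark_samplers(
    fit: Callable[..., Any],
    samplers: Iterable[str],
    base_kwargs: Dict[str, Any],
) -> Dict[str, float]:
    """Call fit once per sampler name and return wall-clock seconds for each."""
    timings: Dict[str, float] = {}
    for name in samplers:
        start = time.perf_counter()
        fit(nuts_sampler=name, **base_kwargs)
        timings[name] = time.perf_counter() - start
    return timings


# Hypothetical usage with the objects from the post (not run here):
# base = {"X": X, "y": y, "draws": 500, "tune": 500, "chains": 2, "random_seed": rng}
# base["nuts_sampler_kwargs"] = {"chain_method": "vectorized"}  # assumption, see above
# print(benchmark_samplers(mmm.fit, ["pymc", "numpyro", "blackjax"], base))
```

Using fewer draws than the real fit keeps the comparison cheap while still showing whether the GPU backends are faster per draw.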

Thank you