Hello,
I am trying to run an MMM using PyMC-Marketing. Because the model takes a long time to run on CPU, I am using 2 NVIDIA T4 GPUs.
I am using numpyro or blackjax as the sampler in the `fit` function, but I don't see any speed improvement (I have verified that the model is in fact running on the GPUs).
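A minimal way to check which accelerators JAX (the backend behind both numpyro and blackjax) can see is sketched below. This is a generic JAX check, not necessarily how the GPUs were verified above; with a CUDA-enabled `jaxlib` on a machine with 2 T4s it should list two GPU devices, otherwise it falls back to CPU.

```python
import jax

# Backend JAX uses by default ("cpu" if no usable GPU was found).
print("default backend:", jax.default_backend())

# Every device on the default backend; numpyro/blackjax can only
# place work on devices listed here.
devices = jax.devices()
for d in devices:
    print(d.id, d.platform, d.device_kind)

print("device count:", jax.device_count())
```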
Is there something I am missing?
Here is how I am configuring the fit:
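```python
import numpy as np

# Fixed seed derived from the string "mmm" for reproducibility
seed: int = sum(map(ord, "mmm"))
rng: np.random.Generator = np.random.default_rng(seed=seed)

sampler_kwargs = {
    "draws": 5_000,      # posterior draws per chain
    "chains": 2,
    "random_seed": rng,
    "tune": 1000,        # warm-up (tuning) steps per chain
    "progressbar": True,
}

# `mmm`, `X` and `y` are defined earlier (model and data not shown)
mmm.fit(X=X, y=y, nuts_sampler="blackjax", **sampler_kwargs)
```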
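To compare CPU and GPU runs on equal terms, one option is to time each `fit` call with identical `sampler_kwargs`. The sketch below uses a hypothetical helper, `timed`, which is not part of PyMC-Marketing. Keep in mind that with JAX-based samplers the measured wall time also includes JIT compilation of the model.

```python
import time
from typing import Any, Callable, Tuple


def timed(fn: Callable[..., Any], *args: Any, **kwargs: Any) -> Tuple[Any, float]:
    """Hypothetical helper: call fn(*args, **kwargs), return (result, seconds)."""
    start = time.perf_counter()
    result = fn(*args, **kwargs)
    elapsed = time.perf_counter() - start
    return result, elapsed


if __name__ == "__main__":
    # Stand-in workload; with a real model this would be something like
    # timed(mmm.fit, X=X, y=y, nuts_sampler="numpyro", **sampler_kwargs)
    result, seconds = timed(sum, range(1_000_000))
    print(f"result={result}, took {seconds:.3f} s")
```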
Thank you