Hello,
I have a GCP instance with 4 Tesla T4 GPUs. Sampling my model takes 4 hours on the GPUs, whereas running it on the CPU takes only 1 hour.
I found the thread "GPU much slower than CPU" in the PyMC Discourse Questions category,
but I have also read that fitting can be faster with a GPU and JAX.
Here is my setup.
<module 'jax.version' from '/opt/conda/lib/python3.7/site-packages/jax/version.py'>
<module 'jaxlib.version' from '/opt/conda/lib/python3.7/site-packages/jaxlib/version.py'>
PyMC Version: 4.0.1
Aesara Version: 2.7.3
ArviZ Version: 0.12.1
Using GPU with [GpuDevice(id=0, process_index=0), GpuDevice(id=1, process_index=0), GpuDevice(id=2, process_index=0), GpuDevice(id=3, process_index=0)]
Here is my model for reference.
# Coordinate labels for every named dimension used by the model below.
coords = {
    "locations": locations,
    "items": items,
    "months": months,
    # One label per non-zero entry of np.diff(A, axis=0), taken from level 0
    # of df_train.index at that row position.
    # NOTE(review): diff row i compares rows i and i+1 of A, so each label is
    # the entry one step *before* the change. Level 0 of df_train.index is
    # per-observation (see "obs_id"), not per-row of A. Confirm both are
    # intended.
    "changepoints": df_train.index.get_level_values(0)[
        np.nonzero(np.diff(A, axis=0) != 0)[0]
    ],
    # Cosine labels first, then sine labels; each half covers
    # yearly_fourier.shape[1] // 2 harmonics.
    "yearly_components": [
        f'yearly_{kind}_{k + 1}'
        for kind in ('cos', 'sin')
        for k in range(yearly_fourier.shape[1] // 2)
    ],
    # One id per row of df_train, built from its (time, location, item) index.
    "obs_id": [
        f'{loc}_{ts.year}_month_{ts.month}_item_{item}'
        for ts, loc, item in df_train.index.values
    ],
}
# Hierarchical count model: a piecewise-linear trend with changepoints, plus
# yearly Fourier seasonality, plus per-month effects. The linear predictor is
# passed through exp() into a Poisson likelihood, then the model is sampled
# with NumPyro NUTS.
#
# NOTE(review): A, s, yearly_fourier, location_idxs, item_idxs, time_idxs,
# month_idxs, t and y are defined outside this snippet. They are presumably
# NumPy arrays / integer index vectors aligned with "obs_id"; confirm.
#
# NOTE(review) on GPU speed, worth profiling:
#   * Each pm.Deterministic with an "obs_id" dim below (initial_slope,
#     initial_intercept, delta [obs_id x changepoints], yearly_seasonality,
#     mu) is stored for every draw. For large data this can cost a lot of
#     time and memory after sampling.
#   * The graph is built at Aesara's configured floatX. If that is float64,
#     GPUs with low float64 throughput (such as the T4) may be slower than
#     the CPU.
with pm.Model(coords=coords) as model:
    # mutable=True allows these inputs to be replaced later via pm.set_data.
    A = pm.Data('A', A, mutable=True, dims=['time', 'changepoints'])
    s = pm.Data('s', s, mutable=True, dims=['changepoints'])
    # NOTE(review): declared with dims ['obs_id', ...] but indexed with
    # time_idxs further down. Confirm whether rows are per-observation or
    # per-time.
    yearly = pm.Data('yearly_season', yearly_fourier, mutable=True,
                     dims=['obs_id', 'yearly_components'])

    # Slope: global mean plus location and item deviations, each written as
    # scale * offset.
    mu_slope = pm.Normal('mu_slope', mu=0, sigma=0.1)
    sigma_loc_slope = pm.HalfNormal('sigma_loc_slope', sigma=0.1)
    sigma_item_slope = pm.HalfNormal('sigma_item_slope', sigma=0.1)
    offset_loc_slope = pm.Normal('offset_loc_slope', mu=0, sigma=0.1, dims=['locations'])
    offset_item_slope = pm.Normal('offset_item_slope', mu=0, sigma=0.1, dims=['items'])
    loc_slope = sigma_loc_slope * offset_loc_slope
    item_slope = sigma_item_slope * offset_item_slope
    initial_slope = pm.Deterministic(
        'initial_slope',
        mu_slope + loc_slope[location_idxs] + item_slope[item_idxs],
        dims=['obs_id'],
    )

    # Intercept: same structure as the slope.
    mu_intercept = pm.Normal('mu_intercept', mu=0, sigma=0.1)
    sigma_loc_intercept = pm.HalfNormal('sigma_loc_intercept', sigma=0.1)
    sigma_item_intercept = pm.HalfNormal('sigma_item_intercept', sigma=0.1)
    offset_loc_intercept = pm.Normal('offset_loc_intercept', mu=0, sigma=0.1, dims=['locations'])
    offset_item_intercept = pm.Normal('offset_item_intercept', mu=0, sigma=0.1, dims=['items'])
    loc_intercept = sigma_loc_intercept * offset_loc_intercept
    item_intercept = sigma_item_intercept * offset_item_intercept
    initial_intercept = pm.Deterministic(
        'initial_intercept',
        mu_intercept + loc_intercept[location_idxs] + item_intercept[item_idxs],
        dims=['obs_id'],
    )

    # Changepoint rate adjustments (delta), per location/item and changepoint.
    mu_delta = pm.Normal('mu_delta', 0, 0.1)
    sigma_loc_delta = pm.HalfNormal('sigma_loc_delta', sigma=0.1)
    sigma_item_delta = pm.HalfNormal('sigma_item_delta', sigma=0.1)
    offset_loc_delta = pm.Normal('offset_loc_delta', mu=0, sigma=0.25, dims=['locations', 'changepoints'])
    offset_item_delta = pm.Normal('offset_item_delta', mu=0, sigma=0.25, dims=['items', 'changepoints'])
    loc_delta = sigma_loc_delta * offset_loc_delta
    item_delta = sigma_item_delta * offset_item_delta
    delta = pm.Deterministic(
        'delta',
        mu_delta + loc_delta[location_idxs, :] + item_delta[item_idxs, :],
        dims=['obs_id', 'changepoints'],
    )

    # Yearly seasonality: per-location Fourier coefficients dotted with the
    # Fourier features.
    yearly_mu = pm.Normal('yearly_mu', 0, 0.1)
    yearly_sigma = pm.HalfNormal('yearly_sigma', sigma=0.1)
    yearly_beta = pm.Normal('yearly_beta', yearly_mu, yearly_sigma, dims=['locations', 'yearly_components'])
    yearly_seasonality = pm.Deterministic(
        'yearly_seasonality',
        (yearly[time_idxs] * yearly_beta[location_idxs, :]).sum(axis=1),
        dims=['obs_id'],
    )

    # Monthly effects.
    beta_month = pm.Normal('beta_month', mu=0, sigma=0.1, dims=['months'])

    # Piecewise-linear trend. A[t, j] weights changepoint j at time t. The
    # intercept is shifted by -s * delta so the trend stays continuous at
    # each changepoint while the slope changes by delta.
    intercept = initial_intercept + ((-s * A)[time_idxs, :] * delta).sum(axis=1)
    slope = initial_slope + (A[time_idxs, :] * delta).sum(axis=1)
    mu = pm.Deterministic(
        'mu',
        intercept + slope * t + yearly_seasonality + beta_month[month_idxs],
        dims=['obs_id'],
    )

    # Log link: mu is symbolic, so use pm.math.exp (an Aesara op) rather than
    # np.exp, which is meant for concrete NumPy arrays.
    likelihood = pm.Poisson(
        'predicted_eaches',
        mu=pm.math.exp(mu),
        observed=y,
        dims=['obs_id'],
    )

    # Both calls resolve the model from the enclosing `with` context.
    prior = pm.sample_prior_predictive()
    trace = pymc.sampling_jax.sample_numpyro_nuts(tune=2000, draws=1000)
Am I doing something wrong that makes sampling on the GPUs slower than on the CPU?