Hello,
I’m running a model on four Tesla T4 GPUs, and at the end of sampling it fails with an error saying the GPUs are out of memory. I don’t get this when I run the model on the CPU, but the CPU takes about 9 hours to sample.
Is there a way to transfer the samples (or whatever else is taking up space) from the GPUs to the CPUs?
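What I have in mind is something along the lines of jax.device_put, moving whatever arrays are piling up on the devices over to host memory, though I’m not sure where this would hook into PyMC’s sampling (a sketch only; samples_on_gpu is just a placeholder name for whatever array is filling device memory):

import jax

cpu = jax.devices("cpu")[0]                            # host (CPU) device
samples_on_cpu = jax.device_put(samples_on_gpu, cpu)   # copy the array off the GPU

For context, my model is below: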
import numpy as np
import pandas as pd
import pymc as pm
import pymc.sampling_jax

# Integer-encode the index levels: time, location, item, and month
time_idxs, times = pd.factorize(df_train.index.get_level_values(0))
location_idxs, locations = pd.factorize(df_train.index.get_level_values(1))
item_idxs, items = pd.factorize(df_train.index.get_level_values(2))
month_idxs, months = pd.factorize(df_train.index.get_level_values(0).month)
# Normalized time, evenly spaced changepoints, and the changepoint indicator matrix
t = time_idxs / max(time_idxs)
n_changepoints = 8
s = np.linspace(0, np.max(t), n_changepoints + 2)[1:-1]
A = (t[:, None] > s) * 1

# Target variable
y = np.array(df_train['eaches'])
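
# For completeness, a minimal sketch of the create_fourier_features helper used
# below (my assumption of the usual implementation: n cos/sin harmonics of a
# cycle of period p, cos columns first to match the 'yearly_components' coord):
def create_fourier_features(t, n, p):
    x = 2 * np.pi * (np.arange(n) + 1) * t[:, None] / p
    return np.concatenate([np.cos(x), np.sin(x)], axis=1)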
# Yearly seasonality features (5 harmonics)
yearly_fourier = create_fourier_features(t, n=5, p=12/max(time_idxs))
coords={"locations":locations,
"items":items,
'months':months,
'changepoints':df_train.index.get_level_values(0)[np.argwhere(np.diff(A, axis=0) != 0)[:, 0]],
"yearly_components": [f'yearly_{f}_{i+1}' for f in ['cos', 'sin'] for i in range(yearly_fourier.shape[1] // 2)],
"obs_id":[f'{loc}_{time.year}_month_{time.month}_item_{item}' for time, loc, item in df_train.index.values]}
with pm.Model(coords=coords) as model:
    A_ = pm.Data('A', A, mutable=True, dims=['time', 'changepoints'])
    s_ = pm.Data('s', s, mutable=True, dims=['changepoints'])
    yearly = pm.Data('yearly_season', yearly_fourier, mutable=True, dims=['obs_id', 'yearly_components'])
    # Slope: hierarchical over locations and items (non-centered)
    mu_slope = pm.Normal('mu_slope', mu=0, sigma=0.1)
    sigma_loc_slope = pm.HalfNormal('sigma_loc_slope', sigma=0.1)
    sigma_item_slope = pm.HalfNormal('sigma_item_slope', sigma=0.1)
    offset_loc_slope = pm.Normal('offset_loc_slope', mu=0, sigma=0.1, dims=['locations'])
    offset_item_slope = pm.Normal('offset_item_slope', mu=0, sigma=0.1, dims=['items'])
    loc_slope = sigma_loc_slope * offset_loc_slope
    item_slope = sigma_item_slope * offset_item_slope
    initial_slope = pm.Deterministic('initial_slope',
                                     mu_slope + loc_slope[location_idxs] + item_slope[item_idxs],
                                     dims=['obs_id'])
    # Intercept: hierarchical over locations and items (non-centered)
    mu_intercept = pm.Normal('mu_intercept', mu=0, sigma=0.1)
    sigma_loc_intercept = pm.HalfNormal('sigma_loc_intercept', sigma=0.1)
    sigma_item_intercept = pm.HalfNormal('sigma_item_intercept', sigma=0.1)
    offset_loc_intercept = pm.Normal('offset_loc_intercept', mu=0, sigma=0.1, dims=['locations'])
    offset_item_intercept = pm.Normal('offset_item_intercept', mu=0, sigma=0.1, dims=['items'])
    loc_intercept = sigma_loc_intercept * offset_loc_intercept
    item_intercept = sigma_item_intercept * offset_item_intercept
    initial_intercept = pm.Deterministic('initial_intercept',
                                         mu_intercept + loc_intercept[location_idxs] + item_intercept[item_idxs],
                                         dims=['obs_id'])
    # Changepoint deltas: hierarchical trend offsets at each changepoint
    mu_delta = pm.Normal('mu_delta', 0, 0.1)
    sigma_loc_delta = pm.HalfNormal('sigma_loc_delta', sigma=0.1)
    sigma_item_delta = pm.HalfNormal('sigma_item_delta', sigma=0.1)
    offset_loc_delta = pm.Normal('offset_loc_delta', mu=0, sigma=0.25, dims=['locations', 'changepoints'])
    offset_item_delta = pm.Normal('offset_item_delta', mu=0, sigma=0.25, dims=['items', 'changepoints'])
    loc_delta = sigma_loc_delta * offset_loc_delta
    item_delta = sigma_item_delta * offset_item_delta
    delta = pm.Deterministic('delta', mu_delta + loc_delta[location_idxs, :] + item_delta[item_idxs, :],
                             dims=['obs_id', 'changepoints'])
    # Yearly seasonality (Fourier terms, per location)
    yearly_mu = pm.Normal('yearly_mu', 0, 0.1)
    yearly_sigma = pm.HalfNormal('yearly_sigma', sigma=0.1)
    yearly_beta = pm.Normal('yearly_beta', yearly_mu, yearly_sigma, dims=['locations', 'yearly_components'])
    yearly_seasonality = pm.Deterministic('yearly_seasonality',
                                          (yearly[time_idxs] * yearly_beta[location_idxs, :]).sum(axis=1),
                                          dims=['obs_id'])
    # Monthly effects
    beta_month = pm.Normal('beta_month', mu=0, sigma=0.1, dims=['months'])

    # Piecewise-linear trend plus seasonality, exponentiated for the Poisson likelihood
    intercept = initial_intercept + ((-s_ * A_)[time_idxs, :] * delta).sum(axis=1)
    slope = initial_slope + (A_[time_idxs, :] * delta).sum(axis=1)
    mu = pm.Deterministic('mu', intercept + slope * t + yearly_seasonality + beta_month[month_idxs], dims=['obs_id'])

    likelihood = pm.Poisson('predicted_eaches',
                            mu=pm.math.exp(mu),
                            observed=y,
                            dims=['obs_id'])
trace = pymc.sampling_jax.sample_numpyro_nuts(model=model, tune=2000, draws=1000, chain_method="vectorized")
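
While reading the docs I also noticed the postprocessing_backend argument to sample_numpyro_nuts. If I understand it correctly, something like the call below would at least move the post-sampling transformations onto the CPU; is that the right knob here? (I’m guessing at the usage.)

trace = pymc.sampling_jax.sample_numpyro_nuts(model=model, tune=2000, draws=1000,
                                              chain_method="vectorized",
                                              postprocessing_backend="cpu")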