Is there a way to transfer PyMC data from the GPU to the CPU?

Hello,

I’m running a model on four Tesla T4 GPUs. At the end of sampling, it fails with an error saying the GPUs are out of memory. I don’t get this error when I run the model on the CPU, but then sampling takes 9 hours.

Is there a way to move the samples (or whatever is taking up the memory) from the GPUs to the CPU? For context, my model is below:

# --- Index factorization ---------------------------------------------------
# Integer codes for the three levels of df_train's MultiIndex (level 0 is used
# as a datetime-like time level below; levels 1 and 2 are presumably location
# and item — TODO confirm against how df_train is built).
# sort=True on the time level makes the codes chronological regardless of row
# order. Without it, pd.factorize numbers times by first appearance, and the
# trend, changepoint and Fourier features below would silently be scrambled
# whenever df_train is not sorted by time.
time_idxs, times = pd.factorize(df_train.index.get_level_values(0), sort=True)
location_idxs, locations = pd.factorize(df_train.index.get_level_values(1))
item_idxs, items = pd.factorize(df_train.index.get_level_values(2))
month_idxs, months = pd.factorize(df_train.index.get_level_values(0).month)

# Per-observation time, rescaled so the first period is 0 and the last is 1.
t = time_idxs / time_idxs.max()
n_changepoints = 8
# Changepoint positions: evenly spaced interior points of [0, max(t)].
s = np.linspace(0, np.max(t), n_changepoints + 2)[1:-1]
# A[i, k] == 1 once observation i's time is strictly past changepoint k.
A = (t[:, None] > s) * 1

# Target variable
y = np.array(df_train['eaches'])

# Yearly seasonality. The period of 12 time steps presumably means the data
# are monthly (TODO confirm).
yearly_fourier = create_fourier_features(t, n=5, p=12 / time_idxs.max())

# Label each changepoint with the last time at or before it. The labels are
# computed from the sorted unique times, so there are always exactly
# n_changepoints of them, independent of df_train's row order.
# time_grid[j] is the rescaled value of time code j, the same values as in t.
time_grid = np.arange(len(times)) / time_idxs.max()
changepoint_times = times[np.searchsorted(time_grid, s, side='right') - 1]

coords = {
    "locations": locations,
    "items": items,
    "months": months,
    "changepoints": changepoint_times,
    # Assumes create_fourier_features returns all cos columns followed by all
    # sin columns — TODO confirm against its implementation.
    "yearly_components": [
        f'yearly_{f}_{i+1}'
        for f in ['cos', 'sin']
        for i in range(yearly_fourier.shape[1] // 2)
    ],
    "obs_id": [
        f'{loc}_{time.year}_month_{time.month}_item_{item}'
        for time, loc, item in df_train.index.values
    ],
}

with pm.Model(coords=coords) as model:
    # Design data. t is computed per row of df_train, so A and yearly_fourier
    # already have one row per observation. Both are therefore indexed by
    # 'obs_id' and used directly below. Indexing them with time_idxs would
    # pick the rows of the first n_times observations instead of each
    # observation's own row.
    A_ = pm.Data('A', A, mutable=True, dims=['obs_id', 'changepoints'])
    s_ = pm.Data('s', s, mutable=True, dims=['changepoints'])
    yearly = pm.Data('yearly_season', yearly_fourier, mutable=True, dims=['obs_id', 'yearly_components'])

    # NOTE: every pm.Deterministic below that has an 'obs_id' dim is stored
    # for every draw of every chain. 'delta' (obs_id x changepoints) is the
    # largest. These dominate the trace's memory footprint; keep only the ones
    # you actually need in the posterior.

    # Slope: global mean plus scaled per-location and per-item offsets.
    mu_slope = pm.Normal('mu_slope', mu=0, sigma=0.1)
    sigma_loc_slope = pm.HalfNormal('sigma_loc_slope', sigma=0.1)
    sigma_item_slope = pm.HalfNormal('sigma_item_slope', sigma=0.1)
    offset_loc_slope = pm.Normal('offset_loc_slope', mu=0, sigma=0.1, dims=['locations'])
    offset_item_slope = pm.Normal('offset_item_slope', mu=0, sigma=0.1, dims=['items'])

    loc_slope = sigma_loc_slope * offset_loc_slope
    item_slope = sigma_item_slope * offset_item_slope
    initial_slope = pm.Deterministic(
        'initial_slope',
        mu_slope + loc_slope[location_idxs] + item_slope[item_idxs],
        dims=['obs_id'],
    )

    # Intercept: same structure as the slope.
    mu_intercept = pm.Normal('mu_intercept', mu=0, sigma=0.1)
    sigma_loc_intercept = pm.HalfNormal('sigma_loc_intercept', sigma=0.1)
    sigma_item_intercept = pm.HalfNormal('sigma_item_intercept', sigma=0.1)
    offset_loc_intercept = pm.Normal('offset_loc_intercept', mu=0, sigma=0.1, dims=['locations'])
    offset_item_intercept = pm.Normal('offset_item_intercept', mu=0, sigma=0.1, dims=['items'])

    loc_intercept = sigma_loc_intercept * offset_loc_intercept
    item_intercept = sigma_item_intercept * offset_item_intercept
    initial_intercept = pm.Deterministic(
        'initial_intercept',
        mu_intercept + loc_intercept[location_idxs] + item_intercept[item_idxs],
        dims=['obs_id'],
    )

    # Slope changes at each changepoint, per location and per item.
    mu_delta = pm.Normal('mu_delta', 0, 0.1)
    sigma_loc_delta = pm.HalfNormal('sigma_loc_delta', sigma=0.1)
    sigma_item_delta = pm.HalfNormal('sigma_item_delta', sigma=0.1)
    offset_loc_delta = pm.Normal('offset_loc_delta', mu=0, sigma=0.25, dims=['locations', 'changepoints'])
    offset_item_delta = pm.Normal('offset_item_delta', mu=0, sigma=0.25, dims=['items', 'changepoints'])

    loc_delta = sigma_loc_delta * offset_loc_delta
    item_delta = sigma_item_delta * offset_item_delta
    delta = pm.Deterministic(
        'delta',
        mu_delta + loc_delta[location_idxs, :] + item_delta[item_idxs, :],
        dims=['obs_id', 'changepoints'],
    )

    # Yearly seasonality: per-location Fourier coefficients, each drawn from
    # a shared hierarchical prior.
    yearly_mu = pm.Normal('yearly_mu', 0, 0.1)
    yearly_sigma = pm.HalfNormal('yearly_sigma', sigma=0.1)
    yearly_beta = pm.Normal('yearly_beta', yearly_mu, yearly_sigma, dims=['locations', 'yearly_components'])
    yearly_seasonality = pm.Deterministic(
        'yearly_seasonality',
        (yearly * yearly_beta[location_idxs, :]).sum(axis=1),
        dims=['obs_id'],
    )

    # Month-of-year effects.
    beta_month = pm.Normal('beta_month', mu=0, sigma=0.1, dims=['months'])

    # Piecewise-linear trend. Each changepoint k that an observation is past
    # adds delta[:, k] to the slope and -s[k] * delta[:, k] to the intercept.
    # This keeps the trend continuous at s[k].
    intercept = initial_intercept + ((-s_ * A_) * delta).sum(axis=1)
    slope = initial_slope + (A_ * delta).sum(axis=1)

    mu = pm.Deterministic(
        'mu',
        intercept + slope * t + yearly_seasonality + beta_month[month_idxs],
        dims=['obs_id'],
    )

    # Poisson likelihood on the counts, with a log link (rate = exp(mu)).
    likelihood = pm.Poisson(
        'predicted_eaches',
        mu=pm.math.exp(mu),
        observed=y,
        dims=['obs_id'],
    )
# Sample with NumPyro NUTS on the JAX backend.
# postprocessing_backend="cpu" moves the raw draws off the GPU before the
# post-sampling step that back-transforms constrained variables and evaluates
# every Deterministic for every draw. That step is presumably where the GPU
# runs out of memory, since the failure happens at the end of sampling. This
# requires a PyMC version whose sample_numpyro_nuts accepts
# postprocessing_backend.
# NOTE: chain_method="vectorized" vmaps all chains on a single device.
# "parallel" would instead pmap one chain per available GPU.
trace = pymc.sampling_jax.sample_numpyro_nuts(
    model=model,
    tune=2000,
    draws=1000,
    chain_method="vectorized",
    postprocessing_backend="cpu",
)