Has anyone had memory issues with Jax/GPU specifically?

Hello,

I’ve been trying to fit a model to a dataset of about 143K rows on 4 GPUs and keep running into memory issues.

I added the idata_kwargs=dict(log_likelihood=False) argument with no luck.

Here is the full error message:

---------------------------------------------------------------------------
XlaRuntimeError                           Traceback (most recent call last)
/tmp/ipykernel_45777/372256884.py in <module>
----> 1 trace = pymc.sampling_jax.sample_numpyro_nuts(model=model, tune=2000, draws = 1000, idata_kwargs=dict(log_likelihood = False))
      2 # trace = pm.sample(model=model, tune=2000, draws = 1000,  idata_kwargs=dict(log_likelihood = False))

/opt/conda/lib/python3.7/site-packages/pymc/sampling_jax.py in sample_numpyro_nuts(draws, tune, chains, target_accept, random_seed, initvals, model, var_names, progress_bar, keep_untransformed, chain_method, postprocessing_backend, idata_kwargs, nuts_kwargs)
    538     jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=vars_to_sample)
    539     result = jax.vmap(jax.vmap(jax_fn))(
--> 540         *jax.device_put(raw_mcmc_samples, jax.devices(postprocessing_backend)[0])
    541     )
    542     mcmc_samples = {v.name: r for v, r in zip(vars_to_sample, result)}

    [... skipping hidden 6 frame]

/opt/conda/lib/python3.7/site-packages/aesara/link/utils.py in jax_funcified_fgraph(mu_slope, sigma_loc_slope_log_, sigma_item_slope_log_, offset_loc_slope, offset_item_slope, mu_intercept, sigma_loc_intercept_log_, sigma_item_intercept_log_, offset_loc_intercept, offset_item_intercept, mu_delta, sigma_loc_delta_log_, sigma_item_delta_log_, offset_loc_delta, offset_item_delta, yearly_mu, yearly_sigma_log_, yearly_beta, beta_month)
     16     yearly_sigma = exp(yearly_sigma_log_)
     17     # AdvancedSubtensor1(yearly_beta, TensorConstant{[ 0  0  0 .. 14 14 14]})
---> 18     auto_64007 = subtensor(yearly_beta, auto_1489)
     19     # Elemwise{Mul}[(0, 1)](TensorConstant{[[1. 1. 1...0. 0. 0.]]}, AdvancedSubtensor1.0)
     20     auto_64213 = elemwise(auto_63448, auto_64007)

/opt/conda/lib/python3.7/site-packages/aesara/link/jax/dispatch.py in subtensor(x, *ilists)
    595             indices = indices[0]
    596 
--> 597         return x.__getitem__(indices)
    598 
    599     return subtensor

    [... skipping hidden 1 frame]

/opt/conda/lib/python3.7/site-packages/jax/_src/numpy/lax_numpy.py in _rewriting_take(arr, idx, indices_are_sorted, unique_indices, mode, fill_value)
   3596   treedef, static_idx, dynamic_idx = _split_index_for_jit(idx, arr.shape)
   3597   return _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted,
-> 3598                  unique_indices, mode, fill_value)
   3599 
   3600 # TODO(phawkins): re-enable jit after fixing excessive recompilation for

/opt/conda/lib/python3.7/site-packages/jax/_src/numpy/lax_numpy.py in _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted, unique_indices, mode, fill_value)
   3626       unique_indices=unique_indices or indexer.unique_indices,
   3627       indices_are_sorted=indices_are_sorted or indexer.indices_are_sorted,
-> 3628       mode=mode, fill_value=fill_value)
   3629 
   3630   # Reverses axes with negative strides.

    [... skipping hidden 16 frame]

/opt/conda/lib/python3.7/site-packages/jax/_src/dispatch.py in _execute_compiled(name, compiled, input_handler, output_buffer_counts, result_handler, has_unordered_effects, ordered_effects, kept_var_idx, *args)
    731     in_flat, token_handler = _add_tokens(has_unordered_effects, ordered_effects,
    732                                          device, in_flat)
--> 733   out_flat = compiled.execute(in_flat)
    734   check_special(name, out_flat)
    735   out_bufs = unflatten(out_flat, output_buffer_counts)

XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 45908480000 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
             parameter allocation:    5.12MiB
              constant allocation:         0B
        maybe_live_out allocation:   42.75GiB
     preallocated temp allocation:         0B
                 total allocation:   42.76GiB
              total fragmentation:         0B (0.00%)
Peak buffers:
	Buffer 1:
		Size: 42.75GiB
		Operator: op_name="jit(gather)/jit(main)/gather[dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 1, 3), collapsed_slice_dims=(2,), start_index_map=(2,)) slice_sizes=(4, 1000, 1, 10) unique_indices=False indices_are_sorted=False mode=GatherScatterMode.PROMISE_IN_BOUNDS fill_value=None]" source_file="/opt/conda/lib/python3.7/site-packages/aesara/link/jax/dispatch.py" source_line=597
		XLA Label: gather
		Shape: f64[4,1000,143464,10]
		==========================

	Buffer 2:
		Size: 4.58MiB
		Entry Parameter Subshape: f64[4,1000,15,10]
		==========================

	Buffer 3:
		Size: 560.4KiB
		Entry Parameter Subshape: s32[143464,1]
		==========================
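For what it’s worth, that peak buffer accounts for the entire failed allocation: a float64 array of shape (4, 1000, 143464, 10), i.e. (chains, draws, rows, 10), works out to exactly the 45908480000 bytes in the message. A quick back-of-the-envelope check (not part of the original output):

# Back-of-the-envelope check of the f64[4,1000,143464,10] peak buffer above
chains, draws, rows, cols = 4, 1000, 143_464, 10
nbytes = chains * draws * rows * cols * 8   # 8 bytes per float64
print(nbytes)                               # 45908480000, the failed allocation
print(nbytes / 2**30)                       # ~42.76 GiB, matching the report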

I also tried setting aesara.config.floatX = "float32" to reduce the memory footprint, and got the strange error below.

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
/tmp/ipykernel_4317/372256884.py in <module>
----> 1 trace = pymc.sampling_jax.sample_numpyro_nuts(model=model, tune=2000, draws = 1000, idata_kwargs=dict(log_likelihood = False))
      2 # trace = pm.sample(model=model, tune=2000, draws = 1000,  idata_kwargs=dict(log_likelihood = False))

/opt/conda/lib/python3.7/site-packages/pymc/sampling_jax.py in sample_numpyro_nuts(draws, tune, chains, target_accept, random_seed, initvals, model, var_names, progress_bar, keep_untransformed, chain_method, postprocessing_backend, idata_kwargs, nuts_kwargs)
    521         init_params=init_params,
    522         extra_fields=(
--> 523             "num_steps",
    524             "potential_energy",
    525             "energy",

/opt/conda/lib/python3.7/site-packages/numpyro/infer/mcmc.py in run(self, rng_key, extra_fields, init_params, *args, **kwargs)
    597                 states, last_state = _laxmap(partial_map_fn, map_args)
    598             elif self.chain_method == "parallel":
--> 599                 states, last_state = pmap(partial_map_fn)(map_args)
    600             else:
    601                 assert self.chain_method == "vectorized"

    [... skipping hidden 17 frame]

/opt/conda/lib/python3.7/site-packages/numpyro/infer/mcmc.py in _single_chain_mcmc(self, init, args, kwargs, collect_fields)
    414             progbar_desc=partial(_get_progbar_desc_str, lower_idx, phase),
    415             diagnostics_fn=diagnostics,
--> 416             num_chains=self.num_chains if self.chain_method == "parallel" else 1,
    417         )
    418         states, last_val = collect_vals

/opt/conda/lib/python3.7/site-packages/numpyro/util.py in fori_collect(lower, upper, body_fun, init_val, transform, progbar, return_last_val, collection_size, thinning, **progbar_opts)
    337     if not progbar:
    338         last_val, collection, _, _ = fori_loop(
--> 339             0, upper, _body_fn, (init_val, collection, start_idx, thinning)
    340         )
    341     elif num_chains > 1:

/opt/conda/lib/python3.7/site-packages/numpyro/util.py in fori_loop(lower, upper, body_fun, init_val)
    139         return val
    140     else:
--> 141         return lax.fori_loop(lower, upper, body_fun, init_val)
    142 
    143 

    [... skipping hidden 14 frame]

/opt/conda/lib/python3.7/site-packages/numpyro/util.py in _body_fn(i, vals)
    321     def _body_fn(i, vals):
    322         val, collection, start_idx, thinning = vals
--> 323         val = body_fun(val)
    324         idx = (i - start_idx) // thinning
    325         collection = cond(

/opt/conda/lib/python3.7/site-packages/numpyro/infer/mcmc.py in _sample_fn_nojit_args(state, sampler, args, kwargs)
    170 def _sample_fn_nojit_args(state, sampler, args, kwargs):
    171     # state is a tuple of size 1 - containing HMCState
--> 172     return (sampler.sample(state[0], args, kwargs),)
    173 
    174 

/opt/conda/lib/python3.7/site-packages/numpyro/infer/hmc.py in sample(self, state, model_args, model_kwargs)
    769         :return: Next `state` after running HMC.
    770         """
--> 771         return self._sample_fn(state, model_args, model_kwargs)
    772 
    773     def __getstate__(self):

/opt/conda/lib/python3.7/site-packages/numpyro/infer/hmc.py in sample_kernel(hmc_state, model_args, model_kwargs)
    472             model_kwargs,
    473             rng_key_transition,
--> 474             *hmc_length_args,
    475         )
    476         # not update adapt_state after warmup phase

/opt/conda/lib/python3.7/site-packages/numpyro/infer/hmc.py in _nuts_next(step_size, inverse_mass_matrix, vv_state, model_args, model_kwargs, rng_key, max_treedepth_current)
    413             rng_key,
    414             max_delta_energy=max_delta_energy,
--> 415             max_tree_depth=(max_treedepth_current, max(max_treedepth)),
    416         )
    417         accept_prob = binary_tree.sum_accept_probs / binary_tree.num_proposals

/opt/conda/lib/python3.7/site-packages/numpyro/infer/hmc_util.py in build_tree(verlet_update, kinetic_fn, verlet_state, inverse_mass_matrix, step_size, rng_key, max_delta_energy, max_tree_depth)
   1175 
   1176     state = (tree, rng_key)
-> 1177     tree, _ = while_loop(_cond_fn, _body_fn, state)
   1178     return tree
   1179 

/opt/conda/lib/python3.7/site-packages/numpyro/util.py in while_loop(cond_fun, body_fun, init_val)
    129         return val
    130     else:
--> 131         return lax.while_loop(cond_fun, body_fun, init_val)
    132 
    133 

    [... skipping hidden 11 frame]

/opt/conda/lib/python3.7/site-packages/numpyro/infer/hmc_util.py in _body_fn(state)
   1170             max_delta_energy,
   1171             r_ckpts,
-> 1172             r_sum_ckpts,
   1173         )
   1174         return tree, key

/opt/conda/lib/python3.7/site-packages/numpyro/infer/hmc_util.py in _double_tree(current_tree, vv_update, kinetic_fn, inverse_mass_matrix, step_size, going_right, rng_key, energy_current, max_delta_energy, r_ckpts, r_sum_ckpts)
    925         max_delta_energy,
    926         r_ckpts,
--> 927         r_sum_ckpts,
    928     )
    929 

/opt/conda/lib/python3.7/site-packages/numpyro/infer/hmc_util.py in _iterative_build_subtree(prototype_tree, vv_update, kinetic_fn, inverse_mass_matrix, step_size, going_right, rng_key, energy_current, max_delta_energy, r_ckpts, r_sum_ckpts)
   1060 
   1061     tree, turning, _, _, _ = while_loop(
-> 1062         _cond_fn, _body_fn, (basetree, False, r_ckpts, r_sum_ckpts, rng_key)
   1063     )
   1064     # update depth and turning condition

/opt/conda/lib/python3.7/site-packages/numpyro/util.py in while_loop(cond_fun, body_fun, init_val)
    129         return val
    130     else:
--> 131         return lax.while_loop(cond_fun, body_fun, init_val)
    132 
    133 

    [... skipping hidden 11 frame]

/opt/conda/lib/python3.7/site-packages/numpyro/infer/hmc_util.py in _body_fn(state)
   1027                 transition_rng_key,
   1028             ),
-> 1029             lambda x: _combine_tree(*x, False),
   1030         )
   1031 

/opt/conda/lib/python3.7/site-packages/numpyro/util.py in cond(pred, true_operand, true_fun, false_operand, false_fun)
    119             return false_fun(false_operand)
    120     else:
--> 121         return lax.cond(pred, true_operand, true_fun, false_operand, false_fun)
    122 
    123 

    [... skipping hidden 16 frame]

/opt/conda/lib/python3.7/site-packages/numpyro/infer/hmc_util.py in <lambda>(x)
   1027                 transition_rng_key,
   1028             ),
-> 1029             lambda x: _combine_tree(*x, False),
   1030         )
   1031 

/opt/conda/lib/python3.7/site-packages/numpyro/infer/hmc_util.py in _combine_tree(current_tree, new_tree, inverse_mass_matrix, going_right, rng_key, biased_transition)
    776         ),
    777         (new_tree, current_tree),
--> 778         lambda trees: (
    779             trees[0].z_left,
    780             trees[0].r_left,

/opt/conda/lib/python3.7/site-packages/numpyro/util.py in cond(pred, true_operand, true_fun, false_operand, false_fun)
    119             return false_fun(false_operand)
    120     else:
--> 121         return lax.cond(pred, true_operand, true_fun, false_operand, false_fun)
    122 
    123 

    [... skipping hidden 4 frame]

/opt/conda/lib/python3.7/site-packages/jax/_src/lax/control_flow/common.py in _check_tree_and_avals(what, tree1, avals1, tree2, avals2)
    103     diff = tree_map(_show_diff, tree_unflatten(tree1, avals1),
    104                     tree_unflatten(tree2, avals2))
--> 105     raise TypeError(f"{what} must have identical types, got\n{diff}.")
    106 
    107 

TypeError: true_fun and false_fun output must have identical types, got
(['DIFFERENT ShapedArray(float32[]) vs. ShapedArray(float64[])', 'DIFFERENT ShapedArray(float32[]) vs. ShapedArray(float64[])', 'DIFFERENT ShapedArray(float32[]) vs. ShapedArray(float64[])', 'DIFFERENT ShapedArray(float32[15]) vs. ShapedArray(float64[15])', 'DIFFERENT ShapedArray(float32[217]) vs. ShapedArray(float64[217])', 'DIFFERENT ShapedArray(float32[]) vs. ShapedArray(float64[])', 'DIFFERENT ShapedArray(float32[]) vs. ShapedArray(float64[])', 'DIFFERENT ShapedArray(float32[]) vs. ShapedArray(float64[])', 'DIFFERENT ShapedArray(float32[15]) vs. ShapedArray(float64[15])', 'DIFFERENT ShapedArray(float32[217]) vs. ShapedArray(float64[217])', 'DIFFERENT ShapedArray(float32[]) vs. ShapedArray(float64[])', 'DIFFERENT ShapedArray(float32[]) vs. ShapedArray(float64[])', 'DIFFERENT ShapedArray(float32[]) vs. ShapedArray(float64[])', 'DIFFERENT ShapedArray(float32[15,8]) vs. ShapedArray(float64[15,8])', 'DIFFERENT ShapedArray(float32[217,8]) vs. ShapedArray(float64[217,8])', 'DIFFERENT ShapedArray(float32[]) vs. ShapedArray(float64[])', 'DIFFERENT ShapedArray(float32[]) vs. ShapedArray(float64[])', 'DIFFERENT ShapedArray(float32[15,10]) vs. ShapedArray(float64[15,10])', 'DIFFERENT ShapedArray(float32[12]) vs. ShapedArray(float64[12])'], ['ShapedArray(float64[])', 'ShapedArray(float64[])', 'ShapedArray(float64[])', 'ShapedArray(float64[15])', 'ShapedArray(float64[217])', 'ShapedArray(float64[])', 'ShapedArray(float64[])', 'ShapedArray(float64[])', 'ShapedArray(float64[15])', 'ShapedArray(float64[217])', 'ShapedArray(float64[])', 'ShapedArray(float64[])', 'ShapedArray(float64[])', 'ShapedArray(float64[15,8])', 'ShapedArray(float64[217,8])', 'ShapedArray(float64[])', 'ShapedArray(float64[])', 'ShapedArray(float64[15,10])', 'ShapedArray(float64[12])'], ['DIFFERENT ShapedArray(float32[]) vs. ShapedArray(float64[])', 'DIFFERENT ShapedArray(float32[]) vs. ShapedArray(float64[])', 'DIFFERENT ShapedArray(float32[]) vs. ShapedArray(float64[])', 'DIFFERENT ShapedArray(float32[15]) vs. ShapedArray(float64[15])', 'DIFFERENT ShapedArray(float32[217]) vs. ShapedArray(float64[217])', 'DIFFERENT ShapedArray(float32[]) vs. ShapedArray(float64[])', 'DIFFERENT ShapedArray(float32[]) vs. ShapedArray(float64[])', 'DIFFERENT ShapedArray(float32[]) vs. ShapedArray(float64[])', 'DIFFERENT ShapedArray(float32[15]) vs. ShapedArray(float64[15])', 'DIFFERENT ShapedArray(float32[217]) vs. ShapedArray(float64[217])', 'DIFFERENT ShapedArray(float32[]) vs. ShapedArray(float64[])', 'DIFFERENT ShapedArray(float32[]) vs. ShapedArray(float64[])', 'DIFFERENT ShapedArray(float32[]) vs. ShapedArray(float64[])', 'DIFFERENT ShapedArray(float32[15,8]) vs. ShapedArray(float64[15,8])', 'DIFFERENT ShapedArray(float32[217,8]) vs. ShapedArray(float64[217,8])', 'DIFFERENT ShapedArray(float32[]) vs. ShapedArray(float64[])', 'DIFFERENT ShapedArray(float32[]) vs. ShapedArray(float64[])', 'DIFFERENT ShapedArray(float32[15,10]) vs. ShapedArray(float64[15,10])', 'DIFFERENT ShapedArray(float32[12]) vs. ShapedArray(float64[12])'], ['DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])', 'DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])', 'DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])', 'DIFFERENT ShapedArray(float64[15]) vs. ShapedArray(float32[15])', 'DIFFERENT ShapedArray(float64[217]) vs. ShapedArray(float32[217])', 'DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])', 'DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])', 'DIFFERENT ShapedArray(float64[]) vs. 
ShapedArray(float32[])', 'DIFFERENT ShapedArray(float64[15]) vs. ShapedArray(float32[15])', 'DIFFERENT ShapedArray(float64[217]) vs. ShapedArray(float32[217])', 'DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])', 'DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])', 'DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])', 'DIFFERENT ShapedArray(float64[15,8]) vs. ShapedArray(float32[15,8])', 'DIFFERENT ShapedArray(float64[217,8]) vs. ShapedArray(float32[217,8])', 'DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])', 'DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])', 'DIFFERENT ShapedArray(float64[15,10]) vs. ShapedArray(float32[15,10])', 'DIFFERENT ShapedArray(float64[12]) vs. ShapedArray(float32[12])'], ['ShapedArray(float64[])', 'ShapedArray(float64[])', 'ShapedArray(float64[])', 'ShapedArray(float64[15])', 'ShapedArray(float64[217])', 'ShapedArray(float64[])', 'ShapedArray(float64[])', 'ShapedArray(float64[])', 'ShapedArray(float64[15])', 'ShapedArray(float64[217])', 'ShapedArray(float64[])', 'ShapedArray(float64[])', 'ShapedArray(float64[])', 'ShapedArray(float64[15,8])', 'ShapedArray(float64[217,8])', 'ShapedArray(float64[])', 'ShapedArray(float64[])', 'ShapedArray(float64[15,10])', 'ShapedArray(float64[12])'], ['DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])', 'DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])', 'DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])', 'DIFFERENT ShapedArray(float64[15]) vs. ShapedArray(float32[15])', 'DIFFERENT ShapedArray(float64[217]) vs. ShapedArray(float32[217])', 'DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])', 'DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])', 'DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])', 'DIFFERENT ShapedArray(float64[15]) vs. ShapedArray(float32[15])', 'DIFFERENT ShapedArray(float64[217]) vs. ShapedArray(float32[217])', 'DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])', 'DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])', 'DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])', 'DIFFERENT ShapedArray(float64[15,8]) vs. ShapedArray(float32[15,8])', 'DIFFERENT ShapedArray(float64[217,8]) vs. ShapedArray(float32[217,8])', 'DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])', 'DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])', 'DIFFERENT ShapedArray(float64[15,10]) vs. ShapedArray(float32[15,10])', 'DIFFERENT ShapedArray(float64[12]) vs. ShapedArray(float32[12])']).

The model will run on the CPU but takes 12 hours to fit. The GPUs sample faster but then error out. Is this truly a memory issue, or is there a bug somewhere with JAX specifically?

Update: this will run with ~5,400 rows but not with ~17K rows, which still seems like a very low ceiling. Maybe the model is too big? See below for context.

            "items":items,
            'months':months,
            'changepoints':df_train.index.get_level_values(0)[np.argwhere(np.diff(A, axis=0) != 0)[:, 0]],
            "yearly_components": [f'yearly_{f}_{i+1}' for f in ['cos', 'sin'] for i in range(yearly_fourier.shape[1] // 2)],
            "obs_id":[f'{loc}_{time.year}_month_{time.month}_item_{item}' for time, loc, item in df_train.index.values]}

with pm.Model(coords=coords) as model:
        
    A_ = pm.Data('A', A, mutable=True, dims=['time', 'changepoints'])
    s_ = pm.Data('s', s, mutable=True, dims=['changepoints'])
    t_ = pm.Data('t', t, mutable=True, dims=['time'])
    yearly = pm.Data('yearly_season', yearly_fourier, mutable=True, dims=['obs_id', 'yearly_components'])

    # Slope
    mu_slope = pm.Normal('mu_slope', mu=0, sigma=0.1)
    sigma_loc_slope = pm.HalfNormal('sigma_loc_slope', sigma=0.1)
    sigma_item_slope = pm.HalfNormal('sigma_item_slope', sigma=0.1)
    offset_loc_slope = pm.Normal('offset_loc_slope', mu=0, sigma=0.1, dims=['locations'])
    offset_item_slope = pm.Normal('offset_item_slope', mu=0, sigma=0.1, dims=['items'])

    loc_slope = sigma_loc_slope * offset_loc_slope
    item_slope = sigma_item_slope * offset_item_slope
    initial_slope = mu_slope + loc_slope[location_idxs] + item_slope[item_idxs]
                                     

    # Intercept
    mu_intercept = pm.Normal('mu_intercept', mu=0, sigma=0.1)
    sigma_loc_intercept = pm.HalfNormal('sigma_loc_intercept', sigma=0.1)
    sigma_item_intercept = pm.HalfNormal('sigma_item_intercept', sigma=0.1)
    offset_loc_intercept = pm.Normal('offset_loc_intercept', mu=0, sigma=0.1, dims=['locations'])
    offset_item_intercept = pm.Normal('offset_item_intercept', mu=0, sigma=0.1, dims=['items'])
    

    loc_intercept = sigma_loc_intercept * offset_loc_intercept
    item_intercept = sigma_item_intercept * offset_item_intercept
    initial_intercept = mu_intercept + loc_intercept[location_idxs] + item_intercept[item_idxs]
    
    # Offsets
    mu_delta = pm.Normal('mu_delta', 0, 0.1)
    sigma_loc_delta = pm.HalfNormal('sigma_loc_delta', sigma=0.1)
    sigma_item_delta = pm.HalfNormal('sigma_item_delta', sigma=0.1)
    offset_loc_delta = pm.Normal('offset_loc_delta', mu=0, sigma=0.25, dims=['locations', 'changepoints'])
    offset_item_delta = pm.Normal('offset_item_delta', mu=0, sigma=0.25, dims=['items', 'changepoints'])

    loc_delta = sigma_loc_delta * offset_loc_delta
    item_delta = sigma_item_delta * offset_item_delta
    delta = mu_delta + loc_delta[location_idxs, :] + item_delta[item_idxs, :]
    
    # Yearly seasonality
    yearly_mu = pm.Normal('yearly_mu', 0, 0.1)
    yearly_sigma = pm.HalfNormal('yearly_sigma', sigma=0.1)
    yearly_beta = pm.Normal('yearly_beta', yearly_mu, yearly_sigma, dims=['locations', 'yearly_components'])
    yearly_seasonality = pm.Deterministic('yearly_seasonality', (yearly[time_idxs] * yearly_beta[location_idxs, :]).sum(axis=1), dims=['obs_id'])

    # Monthly Effects
    beta_month = pm.Normal('beta_month', mu=0, sigma=0.1, dims=['months'])

    intercept = initial_intercept + ((-s_ * A_)[time_idxs, :] * delta).sum(axis=1)
    slope = initial_slope + (A_[time_idxs, :] * delta).sum(axis=1)

    mu = pm.Deterministic('mu', intercept + slope * t_ + yearly_seasonality + beta_month[month_idxs], dims=['obs_id'])
    
    likelihood = pm.Poisson('predicted_eaches',
                            mu=pm.math.exp(mu),
                            observed=y,
                            dims=['obs_id'],)

trace = pymc.sampling_jax.sample_numpyro_nuts(model=model, tune=2000, chain_method='vectorized', draws=1000, idata_kwargs=dict(log_likelihood=False))

I was seeing the same memory error on a machine with 4 GPUs and 72K rows of data using the following sampling setup:

trace = pymc.sampling_jax.sample_numpyro_nuts(postprocessing_backend="gpu", chain_method="parallel")

Adding the idata keyword argument DID work for me:

trace = pymc.sampling_jax.sample_numpyro_nuts(
    postprocessing_backend="gpu", chain_method="parallel", idata_kwargs={"log_likelihood": False}
)

But notice that I’m also running the sampling in parallel rather than vectorized; not sure if that will make a difference for you. Good luck!


Try this sequence:

import aesara
aesara.config.floatX="float32"

before running

pm.Model

maybe even before running

import pymc

(but I’m not sure if that is required).
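Putting that sequence together, it would look roughly like this (just a sketch of the ordering; whether the pymc import really has to come after is the part I’m not sure about):

import aesara
aesara.config.floatX = "float32"   # set the default float type first

import pymc as pm                  # import, and only then build the model

with pm.Model() as model:
    ...  # your model definition goes here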

Also, use a memory allocation scheme that de-allocates memory that is no longer needed. In a Jupyter notebook, use

%env XLA_PYTHON_CLIENT_ALLOCATOR=platform

This must be set before any jax method is run, e.g.

jax.default_backend()
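Outside a notebook you can do the same thing by setting the environment variable before JAX is first used, e.g.:

import os
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"  # must be set before JAX initializes

import jax
jax.default_backend()  # JAX initializes its backend here, with the platform allocator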

As for the idata_kwargs part, I use this:

idata_kwargs=dict(log_likelihood=False)

but you have that already.

Sampling only a subset of the draws you intend to collect also decreases the RAM requirements, but then you have to restart sampling to get the rest, which is inconvenient and not trivial.

Last thing to check: inspect the posterior and make sure only the parameters you need are saved.
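
On that last point: sample_numpyro_nuts accepts a var_names argument (you can see it in the signature in your traceback), so you can restrict which variables get post-processed and stored. A rough sketch using names from your model:

trace = pymc.sampling_jax.sample_numpyro_nuts(
    model=model,
    tune=2000,
    draws=1000,
    var_names=["mu_slope", "mu_intercept", "yearly_beta", "beta_month"],  # only the variables you need
    idata_kwargs=dict(log_likelihood=False),
)
trace.posterior  # verify that only the listed variables were saved

In particular, leaving out the large obs_id-length Deterministics ('mu', 'yearly_seasonality') might be exactly what keeps the post-processing arrays small.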