Has anyone had memory issues with Jax/GPU specifically?

Hello,

I’ve been trying to fit a model to about 143K rows of data on 4 GPUs and keep running into memory issues.

I added idata_kwargs=dict(log_likelihood=False) with no luck.

Here is the error message.

---------------------------------------------------------------------------
XlaRuntimeError                           Traceback (most recent call last)
/tmp/ipykernel_45777/372256884.py in <module>
----> 1 trace = pymc.sampling_jax.sample_numpyro_nuts(model=model, tune=2000, draws = 1000, idata_kwargs=dict(log_likelihood = False))
      2 # trace = pm.sample(model=model, tune=2000, draws = 1000,  idata_kwargs=dict(log_likelihood = False))

/opt/conda/lib/python3.7/site-packages/pymc/sampling_jax.py in sample_numpyro_nuts(draws, tune, chains, target_accept, random_seed, initvals, model, var_names, progress_bar, keep_untransformed, chain_method, postprocessing_backend, idata_kwargs, nuts_kwargs)
    538     jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=vars_to_sample)
    539     result = jax.vmap(jax.vmap(jax_fn))(
--> 540         *jax.device_put(raw_mcmc_samples, jax.devices(postprocessing_backend)[0])
    541     )
    542     mcmc_samples = {v.name: r for v, r in zip(vars_to_sample, result)}

    [... skipping hidden 6 frame]

/opt/conda/lib/python3.7/site-packages/aesara/link/utils.py in jax_funcified_fgraph(mu_slope, sigma_loc_slope_log_, sigma_item_slope_log_, offset_loc_slope, offset_item_slope, mu_intercept, sigma_loc_intercept_log_, sigma_item_intercept_log_, offset_loc_intercept, offset_item_intercept, mu_delta, sigma_loc_delta_log_, sigma_item_delta_log_, offset_loc_delta, offset_item_delta, yearly_mu, yearly_sigma_log_, yearly_beta, beta_month)
     16     yearly_sigma = exp(yearly_sigma_log_)
     17     # AdvancedSubtensor1(yearly_beta, TensorConstant{[ 0  0  0 .. 14 14 14]})
---> 18     auto_64007 = subtensor(yearly_beta, auto_1489)
     19     # Elemwise{Mul}[(0, 1)](TensorConstant{[[1. 1. 1...0. 0. 0.]]}, AdvancedSubtensor1.0)
     20     auto_64213 = elemwise(auto_63448, auto_64007)

/opt/conda/lib/python3.7/site-packages/aesara/link/jax/dispatch.py in subtensor(x, *ilists)
    595             indices = indices[0]
    596 
--> 597         return x.__getitem__(indices)
    598 
    599     return subtensor

    [... skipping hidden 1 frame]

/opt/conda/lib/python3.7/site-packages/jax/_src/numpy/lax_numpy.py in _rewriting_take(arr, idx, indices_are_sorted, unique_indices, mode, fill_value)
   3596   treedef, static_idx, dynamic_idx = _split_index_for_jit(idx, arr.shape)
   3597   return _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted,
-> 3598                  unique_indices, mode, fill_value)
   3599 
   3600 # TODO(phawkins): re-enable jit after fixing excessive recompilation for

/opt/conda/lib/python3.7/site-packages/jax/_src/numpy/lax_numpy.py in _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted, unique_indices, mode, fill_value)
   3626       unique_indices=unique_indices or indexer.unique_indices,
   3627       indices_are_sorted=indices_are_sorted or indexer.indices_are_sorted,
-> 3628       mode=mode, fill_value=fill_value)
   3629 
   3630   # Reverses axes with negative strides.

    [... skipping hidden 16 frame]

/opt/conda/lib/python3.7/site-packages/jax/_src/dispatch.py in _execute_compiled(name, compiled, input_handler, output_buffer_counts, result_handler, has_unordered_effects, ordered_effects, kept_var_idx, *args)
    731     in_flat, token_handler = _add_tokens(has_unordered_effects, ordered_effects,
    732                                          device, in_flat)
--> 733   out_flat = compiled.execute(in_flat)
    734   check_special(name, out_flat)
    735   out_bufs = unflatten(out_flat, output_buffer_counts)

XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 45908480000 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
             parameter allocation:    5.12MiB
              constant allocation:         0B
        maybe_live_out allocation:   42.75GiB
     preallocated temp allocation:         0B
                 total allocation:   42.76GiB
              total fragmentation:         0B (0.00%)
Peak buffers:
	Buffer 1:
		Size: 42.75GiB
		Operator: op_name="jit(gather)/jit(main)/gather[dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 1, 3), collapsed_slice_dims=(2,), start_index_map=(2,)) slice_sizes=(4, 1000, 1, 10) unique_indices=False indices_are_sorted=False mode=GatherScatterMode.PROMISE_IN_BOUNDS fill_value=None]" source_file="/opt/conda/lib/python3.7/site-packages/aesara/link/jax/dispatch.py" source_line=597
		XLA Label: gather
		Shape: f64[4,1000,143464,10]
		==========================

	Buffer 2:
		Size: 4.58MiB
		Entry Parameter Subshape: f64[4,1000,15,10]
		==========================

	Buffer 3:
		Size: 560.4KiB
		Entry Parameter Subshape: s32[143464,1]
		==========================
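The 45908480000 bytes in the error are exactly 4 × 1000 × 143,464 × 10 float64 values, and the traceback places the allocation in the post-processing step of sample_numpyro_nuts (the jax.vmap(jax.vmap(jax_fn)) call), i.e. after the sampler itself has finished. A quick check of that multiplication; reading the axes as chains, draws, rows and yearly components is an interpretation of the shape f64[4,1000,143464,10], not something the error states. It also suggests float32 alone would only halve the buffer:

```python
# Shape reported for Buffer 1: f64[4,1000,143464,10].
# Axis meanings (chains, draws, observations, yearly components) are an
# interpretation of the traceback, not stated by the error itself.
chains, draws, n_obs, n_components = 4, 1000, 143_464, 10

itemsize = {"float64": 8, "float32": 4}
buffer_bytes = {
    name: chains * draws * n_obs * n_components * size
    for name, size in itemsize.items()
}

for name, n_bytes in buffer_bytes.items():
    # float64 reproduces the 45908480000 bytes (42.76 GiB) in the OOM message
    print(f"{name}: {n_bytes} bytes = {n_bytes / 2**30:.2f} GiB")
```

If that reading is right, this buffer grows linearly in chains × draws × rows, regardless of how many GPUs are used for sampling.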

I also tried setting aesara.config.floatX = "float32" to reduce memory usage and got the odd error below.

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
/tmp/ipykernel_4317/372256884.py in <module>
----> 1 trace = pymc.sampling_jax.sample_numpyro_nuts(model=model, tune=2000, draws = 1000, idata_kwargs=dict(log_likelihood = False))
      2 # trace = pm.sample(model=model, tune=2000, draws = 1000,  idata_kwargs=dict(log_likelihood = False))

/opt/conda/lib/python3.7/site-packages/pymc/sampling_jax.py in sample_numpyro_nuts(draws, tune, chains, target_accept, random_seed, initvals, model, var_names, progress_bar, keep_untransformed, chain_method, postprocessing_backend, idata_kwargs, nuts_kwargs)
    521         init_params=init_params,
    522         extra_fields=(
--> 523             "num_steps",
    524             "potential_energy",
    525             "energy",

/opt/conda/lib/python3.7/site-packages/numpyro/infer/mcmc.py in run(self, rng_key, extra_fields, init_params, *args, **kwargs)
    597                 states, last_state = _laxmap(partial_map_fn, map_args)
    598             elif self.chain_method == "parallel":
--> 599                 states, last_state = pmap(partial_map_fn)(map_args)
    600             else:
    601                 assert self.chain_method == "vectorized"

    [... skipping hidden 17 frame]

/opt/conda/lib/python3.7/site-packages/numpyro/infer/mcmc.py in _single_chain_mcmc(self, init, args, kwargs, collect_fields)
    414             progbar_desc=partial(_get_progbar_desc_str, lower_idx, phase),
    415             diagnostics_fn=diagnostics,
--> 416             num_chains=self.num_chains if self.chain_method == "parallel" else 1,
    417         )
    418         states, last_val = collect_vals

/opt/conda/lib/python3.7/site-packages/numpyro/util.py in fori_collect(lower, upper, body_fun, init_val, transform, progbar, return_last_val, collection_size, thinning, **progbar_opts)
    337     if not progbar:
    338         last_val, collection, _, _ = fori_loop(
--> 339             0, upper, _body_fn, (init_val, collection, start_idx, thinning)
    340         )
    341     elif num_chains > 1:

/opt/conda/lib/python3.7/site-packages/numpyro/util.py in fori_loop(lower, upper, body_fun, init_val)
    139         return val
    140     else:
--> 141         return lax.fori_loop(lower, upper, body_fun, init_val)
    142 
    143 

    [... skipping hidden 14 frame]

/opt/conda/lib/python3.7/site-packages/numpyro/util.py in _body_fn(i, vals)
    321     def _body_fn(i, vals):
    322         val, collection, start_idx, thinning = vals
--> 323         val = body_fun(val)
    324         idx = (i - start_idx) // thinning
    325         collection = cond(

/opt/conda/lib/python3.7/site-packages/numpyro/infer/mcmc.py in _sample_fn_nojit_args(state, sampler, args, kwargs)
    170 def _sample_fn_nojit_args(state, sampler, args, kwargs):
    171     # state is a tuple of size 1 - containing HMCState
--> 172     return (sampler.sample(state[0], args, kwargs),)
    173 
    174 

/opt/conda/lib/python3.7/site-packages/numpyro/infer/hmc.py in sample(self, state, model_args, model_kwargs)
    769         :return: Next `state` after running HMC.
    770         """
--> 771         return self._sample_fn(state, model_args, model_kwargs)
    772 
    773     def __getstate__(self):

/opt/conda/lib/python3.7/site-packages/numpyro/infer/hmc.py in sample_kernel(hmc_state, model_args, model_kwargs)
    472             model_kwargs,
    473             rng_key_transition,
--> 474             *hmc_length_args,
    475         )
    476         # not update adapt_state after warmup phase

/opt/conda/lib/python3.7/site-packages/numpyro/infer/hmc.py in _nuts_next(step_size, inverse_mass_matrix, vv_state, model_args, model_kwargs, rng_key, max_treedepth_current)
    413             rng_key,
    414             max_delta_energy=max_delta_energy,
--> 415             max_tree_depth=(max_treedepth_current, max(max_treedepth)),
    416         )
    417         accept_prob = binary_tree.sum_accept_probs / binary_tree.num_proposals

/opt/conda/lib/python3.7/site-packages/numpyro/infer/hmc_util.py in build_tree(verlet_update, kinetic_fn, verlet_state, inverse_mass_matrix, step_size, rng_key, max_delta_energy, max_tree_depth)
   1175 
   1176     state = (tree, rng_key)
-> 1177     tree, _ = while_loop(_cond_fn, _body_fn, state)
   1178     return tree
   1179 

/opt/conda/lib/python3.7/site-packages/numpyro/util.py in while_loop(cond_fun, body_fun, init_val)
    129         return val
    130     else:
--> 131         return lax.while_loop(cond_fun, body_fun, init_val)
    132 
    133 

    [... skipping hidden 11 frame]

/opt/conda/lib/python3.7/site-packages/numpyro/infer/hmc_util.py in _body_fn(state)
   1170             max_delta_energy,
   1171             r_ckpts,
-> 1172             r_sum_ckpts,
   1173         )
   1174         return tree, key

/opt/conda/lib/python3.7/site-packages/numpyro/infer/hmc_util.py in _double_tree(current_tree, vv_update, kinetic_fn, inverse_mass_matrix, step_size, going_right, rng_key, energy_current, max_delta_energy, r_ckpts, r_sum_ckpts)
    925         max_delta_energy,
    926         r_ckpts,
--> 927         r_sum_ckpts,
    928     )
    929 

/opt/conda/lib/python3.7/site-packages/numpyro/infer/hmc_util.py in _iterative_build_subtree(prototype_tree, vv_update, kinetic_fn, inverse_mass_matrix, step_size, going_right, rng_key, energy_current, max_delta_energy, r_ckpts, r_sum_ckpts)
   1060 
   1061     tree, turning, _, _, _ = while_loop(
-> 1062         _cond_fn, _body_fn, (basetree, False, r_ckpts, r_sum_ckpts, rng_key)
   1063     )
   1064     # update depth and turning condition

/opt/conda/lib/python3.7/site-packages/numpyro/util.py in while_loop(cond_fun, body_fun, init_val)
    129         return val
    130     else:
--> 131         return lax.while_loop(cond_fun, body_fun, init_val)
    132 
    133 

    [... skipping hidden 11 frame]

/opt/conda/lib/python3.7/site-packages/numpyro/infer/hmc_util.py in _body_fn(state)
   1027                 transition_rng_key,
   1028             ),
-> 1029             lambda x: _combine_tree(*x, False),
   1030         )
   1031 

/opt/conda/lib/python3.7/site-packages/numpyro/util.py in cond(pred, true_operand, true_fun, false_operand, false_fun)
    119             return false_fun(false_operand)
    120     else:
--> 121         return lax.cond(pred, true_operand, true_fun, false_operand, false_fun)
    122 
    123 

    [... skipping hidden 16 frame]

/opt/conda/lib/python3.7/site-packages/numpyro/infer/hmc_util.py in <lambda>(x)
   1027                 transition_rng_key,
   1028             ),
-> 1029             lambda x: _combine_tree(*x, False),
   1030         )
   1031 

/opt/conda/lib/python3.7/site-packages/numpyro/infer/hmc_util.py in _combine_tree(current_tree, new_tree, inverse_mass_matrix, going_right, rng_key, biased_transition)
    776         ),
    777         (new_tree, current_tree),
--> 778         lambda trees: (
    779             trees[0].z_left,
    780             trees[0].r_left,

/opt/conda/lib/python3.7/site-packages/numpyro/util.py in cond(pred, true_operand, true_fun, false_operand, false_fun)
    119             return false_fun(false_operand)
    120     else:
--> 121         return lax.cond(pred, true_operand, true_fun, false_operand, false_fun)
    122 
    123 

    [... skipping hidden 4 frame]

/opt/conda/lib/python3.7/site-packages/jax/_src/lax/control_flow/common.py in _check_tree_and_avals(what, tree1, avals1, tree2, avals2)
    103     diff = tree_map(_show_diff, tree_unflatten(tree1, avals1),
    104                     tree_unflatten(tree2, avals2))
--> 105     raise TypeError(f"{what} must have identical types, got\n{diff}.")
    106 
    107 

TypeError: true_fun and false_fun output must have identical types, got
(['DIFFERENT ShapedArray(float32[]) vs. ShapedArray(float64[])', 'DIFFERENT ShapedArray(float32[]) vs. ShapedArray(float64[])', 'DIFFERENT ShapedArray(float32[]) vs. ShapedArray(float64[])', 'DIFFERENT ShapedArray(float32[15]) vs. ShapedArray(float64[15])', 'DIFFERENT ShapedArray(float32[217]) vs. ShapedArray(float64[217])', 'DIFFERENT ShapedArray(float32[]) vs. ShapedArray(float64[])', 'DIFFERENT ShapedArray(float32[]) vs. ShapedArray(float64[])', 'DIFFERENT ShapedArray(float32[]) vs. ShapedArray(float64[])', 'DIFFERENT ShapedArray(float32[15]) vs. ShapedArray(float64[15])', 'DIFFERENT ShapedArray(float32[217]) vs. ShapedArray(float64[217])', 'DIFFERENT ShapedArray(float32[]) vs. ShapedArray(float64[])', 'DIFFERENT ShapedArray(float32[]) vs. ShapedArray(float64[])', 'DIFFERENT ShapedArray(float32[]) vs. ShapedArray(float64[])', 'DIFFERENT ShapedArray(float32[15,8]) vs. ShapedArray(float64[15,8])', 'DIFFERENT ShapedArray(float32[217,8]) vs. ShapedArray(float64[217,8])', 'DIFFERENT ShapedArray(float32[]) vs. ShapedArray(float64[])', 'DIFFERENT ShapedArray(float32[]) vs. ShapedArray(float64[])', 'DIFFERENT ShapedArray(float32[15,10]) vs. ShapedArray(float64[15,10])', 'DIFFERENT ShapedArray(float32[12]) vs. ShapedArray(float64[12])'], ['ShapedArray(float64[])', 'ShapedArray(float64[])', 'ShapedArray(float64[])', 'ShapedArray(float64[15])', 'ShapedArray(float64[217])', 'ShapedArray(float64[])', 'ShapedArray(float64[])', 'ShapedArray(float64[])', 'ShapedArray(float64[15])', 'ShapedArray(float64[217])', 'ShapedArray(float64[])', 'ShapedArray(float64[])', 'ShapedArray(float64[])', 'ShapedArray(float64[15,8])', 'ShapedArray(float64[217,8])', 'ShapedArray(float64[])', 'ShapedArray(float64[])', 'ShapedArray(float64[15,10])', 'ShapedArray(float64[12])'], ['DIFFERENT ShapedArray(float32[]) vs. ShapedArray(float64[])', 'DIFFERENT ShapedArray(float32[]) vs. ShapedArray(float64[])', 'DIFFERENT ShapedArray(float32[]) vs. 
ShapedArray(float64[])', 'DIFFERENT ShapedArray(float32[15]) vs. ShapedArray(float64[15])', 'DIFFERENT ShapedArray(float32[217]) vs. ShapedArray(float64[217])', 'DIFFERENT ShapedArray(float32[]) vs. ShapedArray(float64[])', 'DIFFERENT ShapedArray(float32[]) vs. ShapedArray(float64[])', 'DIFFERENT ShapedArray(float32[]) vs. ShapedArray(float64[])', 'DIFFERENT ShapedArray(float32[15]) vs. ShapedArray(float64[15])', 'DIFFERENT ShapedArray(float32[217]) vs. ShapedArray(float64[217])', 'DIFFERENT ShapedArray(float32[]) vs. ShapedArray(float64[])', 'DIFFERENT ShapedArray(float32[]) vs. ShapedArray(float64[])', 'DIFFERENT ShapedArray(float32[]) vs. ShapedArray(float64[])', 'DIFFERENT ShapedArray(float32[15,8]) vs. ShapedArray(float64[15,8])', 'DIFFERENT ShapedArray(float32[217,8]) vs. ShapedArray(float64[217,8])', 'DIFFERENT ShapedArray(float32[]) vs. ShapedArray(float64[])', 'DIFFERENT ShapedArray(float32[]) vs. ShapedArray(float64[])', 'DIFFERENT ShapedArray(float32[15,10]) vs. ShapedArray(float64[15,10])', 'DIFFERENT ShapedArray(float32[12]) vs. ShapedArray(float64[12])'], ['DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])', 'DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])', 'DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])', 'DIFFERENT ShapedArray(float64[15]) vs. ShapedArray(float32[15])', 'DIFFERENT ShapedArray(float64[217]) vs. ShapedArray(float32[217])', 'DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])', 'DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])', 'DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])', 'DIFFERENT ShapedArray(float64[15]) vs. ShapedArray(float32[15])', 'DIFFERENT ShapedArray(float64[217]) vs. ShapedArray(float32[217])', 'DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])', 'DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])', 'DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])', 'DIFFERENT ShapedArray(float64[15,8]) vs. 
ShapedArray(float32[15,8])', 'DIFFERENT ShapedArray(float64[217,8]) vs. ShapedArray(float32[217,8])', 'DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])', 'DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])', 'DIFFERENT ShapedArray(float64[15,10]) vs. ShapedArray(float32[15,10])', 'DIFFERENT ShapedArray(float64[12]) vs. ShapedArray(float32[12])'], ['ShapedArray(float64[])', 'ShapedArray(float64[])', 'ShapedArray(float64[])', 'ShapedArray(float64[15])', 'ShapedArray(float64[217])', 'ShapedArray(float64[])', 'ShapedArray(float64[])', 'ShapedArray(float64[])', 'ShapedArray(float64[15])', 'ShapedArray(float64[217])', 'ShapedArray(float64[])', 'ShapedArray(float64[])', 'ShapedArray(float64[])', 'ShapedArray(float64[15,8])', 'ShapedArray(float64[217,8])', 'ShapedArray(float64[])', 'ShapedArray(float64[])', 'ShapedArray(float64[15,10])', 'ShapedArray(float64[12])'], ['DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])', 'DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])', 'DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])', 'DIFFERENT ShapedArray(float64[15]) vs. ShapedArray(float32[15])', 'DIFFERENT ShapedArray(float64[217]) vs. ShapedArray(float32[217])', 'DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])', 'DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])', 'DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])', 'DIFFERENT ShapedArray(float64[15]) vs. ShapedArray(float32[15])', 'DIFFERENT ShapedArray(float64[217]) vs. ShapedArray(float32[217])', 'DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])', 'DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])', 'DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])', 'DIFFERENT ShapedArray(float64[15,8]) vs. ShapedArray(float32[15,8])', 'DIFFERENT ShapedArray(float64[217,8]) vs. ShapedArray(float32[217,8])', 'DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])', 'DIFFERENT ShapedArray(float64[]) vs. 
ShapedArray(float32[])', 'DIFFERENT ShapedArray(float64[15,10]) vs. ShapedArray(float32[15,10])', 'DIFFERENT ShapedArray(float64[12]) vs. ShapedArray(float32[12])']).

The model runs on CPU but takes 12 hours to fit. The GPUs sample faster, then error out. Is this truly a memory issue, or is there a bug somewhere in JAX specifically?

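Update: the model runs with about 5,400 rows but not with about 17K rows, which still seems like too little data to run out of memory. Maybe the model is too big? The model code is below for context.

import numpy as np
import pymc as pm
import pymc.sampling_jax

# The data objects (df_train, A, s, t, yearly_fourier, y, items, months and the
# *_idxs index arrays) are created earlier and are not shown in the post.
coords = {
    # ... (the beginning of this dict was cut off in the original post)
    "items": items,
    "months": months,
    "changepoints": df_train.index.get_level_values(0)[np.argwhere(np.diff(A, axis=0) != 0)[:, 0]],
    "yearly_components": [f"yearly_{f}_{i+1}" for f in ["cos", "sin"] for i in range(yearly_fourier.shape[1] // 2)],
    "obs_id": [f"{loc}_{time.year}_month_{time.month}_item_{item}" for time, loc, item in df_train.index.values],
}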

with pm.Model(coords=coords) as model:

    A_ = pm.Data('A', A, mutable=True, dims=['time', 'changepoints'])
    s_ = pm.Data('s', s, mutable=True, dims=['changepoints'])
    t_ = pm.Data('t', t, mutable=True, dims=['time'])
    yearly = pm.Data('yearly_season', yearly_fourier, mutable=True, dims=['obs_id', 'yearly_components'])

    # Slope
    mu_slope = pm.Normal('mu_slope', mu=0, sigma=0.1)
    sigma_loc_slope = pm.HalfNormal('sigma_loc_slope', sigma=0.1)
    sigma_item_slope = pm.HalfNormal('sigma_item_slope', sigma=0.1)
    offset_loc_slope = pm.Normal('offset_loc_slope', mu=0, sigma=0.1, dims=['locations'])
    offset_item_slope = pm.Normal('offset_item_slope', mu=0, sigma=0.1, dims=['items'])

    loc_slope = sigma_loc_slope * offset_loc_slope
    item_slope = sigma_item_slope * offset_item_slope
    initial_slope = mu_slope + loc_slope[location_idxs] + item_slope[item_idxs]
                                     

    # Intercept
    mu_intercept = pm.Normal('mu_intercept', mu=0, sigma=0.1)
    sigma_loc_intercept = pm.HalfNormal('sigma_loc_intercept', sigma=0.1)
    sigma_item_intercept = pm.HalfNormal('sigma_item_intercept', sigma=0.1)
    offset_loc_intercept = pm.Normal('offset_loc_intercept', mu=0, sigma=0.1, dims=['locations'])
    offset_item_intercept = pm.Normal('offset_item_intercept', mu=0, sigma=0.1, dims=['items'])
    

    loc_intercept = sigma_loc_intercept * offset_loc_intercept
    item_intercept = sigma_item_intercept * offset_item_intercept
    initial_intercept = mu_intercept + loc_intercept[location_idxs] + item_intercept[item_idxs]
    
    # Offsets
    mu_delta = pm.Normal('mu_delta', 0, 0.1)
    sigma_loc_delta = pm.HalfNormal('sigma_loc_delta', sigma=0.1)
    sigma_item_delta = pm.HalfNormal('sigma_item_delta', sigma=0.1)
    offset_loc_delta = pm.Normal('offset_loc_delta', mu=0, sigma=0.25, dims=['locations', 'changepoints'])
    offset_item_delta = pm.Normal('offset_item_delta', mu=0, sigma=0.25, dims=['items', 'changepoints'])

    loc_delta = sigma_loc_delta * offset_loc_delta
    item_delta = sigma_item_delta * offset_item_delta
    delta = mu_delta + loc_delta[location_idxs, :] + item_delta[item_idxs, :]
    
    # Yearly seasonality (Fourier terms)
    yearly_mu = pm.Normal('yearly_mu', 0, 0.1)
    yearly_sigma = pm.HalfNormal('yearly_sigma', sigma=0.1)
    yearly_beta = pm.Normal('yearly_beta', yearly_mu, yearly_sigma, dims=['locations', 'yearly_components'])
    yearly_seasonality = pm.Deterministic('yearly_seasonality', (yearly[time_idxs] * yearly_beta[location_idxs, :]).sum(axis=1), dims=['obs_id'])

    # Monthly Effects
    beta_month = pm.Normal('beta_month', mu=0, sigma=0.1, dims=['months'])

    intercept = initial_intercept + ((-s_ * A_)[time_idxs, :] * delta).sum(axis=1)
    slope = initial_slope + (A_[time_idxs, :] * delta).sum(axis=1)

    mu = pm.Deterministic('mu', intercept + slope * t_ + yearly_seasonality + beta_month[month_idxs], dims=['obs_id'])
    
    likelihood = pm.Poisson('predicted_eaches',
                            mu=pm.math.exp(mu),
                            observed=y,
                            dims=['obs_id'],)

trace = pymc.sampling_jax.sample_numpyro_nuts(model=model, tune=2000, chain_method='vectorized', draws = 1000, idata_kwargs=dict(log_likelihood = False))
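The failing op in the first traceback is the AdvancedSubtensor1(yearly_beta, ...) line, which looks like the yearly_beta[location_idxs, :] term inside the yearly_seasonality Deterministic (yearly_beta has 15 × 10 entries, matching Buffer 2's f64[4,1000,15,10]). That is my reading of the traceback. When post-processing evaluates this expression for every chain and draw at once, the indexed intermediate takes the shape chains × draws × rows × components. A small NumPy stand-in with made-up sizes and random data (not the real model) shows the blow-up:

```python
import numpy as np

# Small hypothetical sizes; the real run had 4 chains, 1000 draws,
# 15 locations, 10 yearly components and 143,464 rows.
chains, draws, n_locations, n_components, n_obs = 2, 5, 15, 10, 100

rng = np.random.default_rng(0)
# Posterior draws of yearly_beta: one (locations, components) matrix per chain and draw.
yearly_beta_draws = rng.normal(size=(chains, draws, n_locations, n_components))
# Hypothetical location index for every observation row.
location_idxs = rng.integers(0, n_locations, size=n_obs)

# yearly_beta[location_idxs, :] evaluated for all chains and draws at once:
# the per-location matrix is copied out once per observation row.
expanded = yearly_beta_draws[:, :, location_idxs, :]
print(expanded.shape)  # (2, 5, 100, 10)
```

The 15 × 10 parameter matrix stays tiny; it is the per-row copy, repeated for every draw, that grows with the data.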

I was seeing the same memory error on a machine with 4 GPUs and 72K rows of data using the following sampling setup:

trace = pymc.sampling_jax.sample_numpyro_nuts(postprocessing_backend="gpu", chain_method="parallel")

Adding the idata keyword argument DID work for me:

trace = pymc.sampling_jax.sample_numpyro_nuts(
    postprocessing_backend="gpu", chain_method="parallel", idata_kwargs={"log_likelihood": False}
)

But notice that I’m also running the sampling in parallel rather than vectorized; not sure if that will make a difference for you. Good luck!


Try this sequence:

import aesara
aesara.config.floatX="float32"

before creating the pm.Model, and possibly even before running

import pymc

(though I’m not sure that is required).
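On the float32 attempt: the TypeError says the two branches of a lax.cond inside NumPyro's NUTS return float32 for values that are float64 in the other branch. One possible source, not confirmed anywhere in this thread, is float64 data or constants created before floatX was changed: standard type promotion turns any float32 × float64 operation back into float64. NumPy shows the rule:

```python
import numpy as np

a32 = np.ones(3, dtype=np.float32)
b64 = np.ones(3)  # NumPy's default float dtype is float64

# Mixing precisions promotes the result to float64 ...
print((a32 * b64).dtype)  # float64
# ... while casting both operands keeps everything in float32.
print((a32 * b64.astype(np.float32)).dtype)  # float32
```

Casting data arrays to float32 before passing them to pm.Data is one way to rule that source out; whether it fixes this particular error is untested.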

Also, use an allocator that frees memory that is no longer needed. In a Jupyter notebook, use

%env XLA_PYTHON_CLIENT_ALLOCATOR=platform

This must be set before any JAX function runs, because the allocator is chosen when JAX initializes its backend, e.g. before

jax.default_backend()
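Outside a notebook, the %env line corresponds to setting the environment variable from Python before JAX is imported. A minimal sketch; the commented-out line is the preallocation switch mentioned later in this thread, shown as an alternative rather than something to combine with it:

```python
import os

# Must run before JAX initializes a backend; importing jax (or pymc's JAX
# sampler) only after these lines is the safe ordering.
# "platform" makes XLA allocate on demand and free buffers that are no
# longer needed, at some cost in speed.
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"

# Alternative: keep the default allocator but turn off preallocation.
# os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

# import jax
# import pymc.sampling_jax
```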

As for the idata_kwargs part I use this:

idata_kwargs=dict(log_likelihood=False)

but you have that already.

Sampling only a subset of the draws you intend to collect also reduces RAM requirements, but you then have to restart sampling to get the rest, which is inconvenient and not trivial.
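As a rough guide to how many draws a chunk could hold, here is a hypothetical helper that budgets only for the single largest buffer from the first error (rows × components floats per chain and draw). Real peak usage will be higher, since other arrays are also in memory:

```python
def max_draws_per_chunk(budget_bytes, chains, n_obs, n_components, itemsize=8):
    """Largest draw count whose (chains, draws, n_obs, n_components) buffer
    fits in budget_bytes. Hypothetical helper; ignores all other allocations."""
    bytes_per_draw = chains * n_obs * n_components * itemsize
    return budget_bytes // bytes_per_draw


# Hypothetical 16 GiB budget with the shapes from the error message.
draws_per_chunk = max_draws_per_chunk(16 * 2**30, chains=4, n_obs=143_464, n_components=10)
print(draws_per_chunk)  # 374
```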

Last thing to check: inspect the posterior and make sure only the parameters you need are saved.
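To see which stored variables dominate, a sketch that estimates each variable's trace size from its per-draw shape. The per-observation Deterministics (mu and yearly_seasonality) are the large ones. The traceback shows sample_numpyro_nuts taking a var_names argument and compiling the post-processing function for vars_to_sample, so leaving those Deterministics out of var_names may avoid computing them at all. I have not verified the exact form var_names expects in this PyMC version, and the 1 GiB cap below is arbitrary:

```python
def stored_bytes(shape, chains=4, draws=1000, itemsize=8):
    """Bytes needed to keep one variable for every chain and draw (float64)."""
    n = 1
    for dim in shape:
        n *= dim
    return chains * draws * n * itemsize


# Per-draw shapes for a few variables of the model above. The 15 locations
# and 12 months appear to match shapes in the error messages; this list is
# illustrative, not the full set of model variables.
per_draw_shapes = {
    "yearly_beta": (15, 10),
    "beta_month": (12,),
    "yearly_seasonality": (143_464,),  # Deterministic with dims obs_id
    "mu": (143_464,),                  # Deterministic with dims obs_id
}

for name, shape in per_draw_shapes.items():
    print(f"{name}: {stored_bytes(shape) / 2**30:.3f} GiB")

# Keep only variables whose full trace stays under a hypothetical 1 GiB cap.
keep = [name for name, shape in per_draw_shapes.items() if stored_bytes(shape) <= 2**30]
print(keep)  # ['yearly_beta', 'beta_month']
```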

Sorry for the late reply. I was on holiday. I tried this and still had some issues. I’ll just stick with the CPU for now.

I have found that disabling JAX's GPU memory preallocation helps with problems like this:

import os
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"  # set before importing jax

However, that does make the sampling quite slow for me and the CPU version seems to be faster.

This worked once, with the GP module's MarginalApprox(); however, with more than ~500 inducing points I run into memory errors again, so it's an improvement but not a solution.

Has anyone found other workarounds? Is the GP module recommended in general? I mean this as a genuine question rather than a criticism: in a fraction of the time, I was able to build similar models and fit them to the same data in R, gpytorch and sklearn. What benefits do I get from using the PyMC implementation?

I really want to use PyMC, since the linear models were so much fun, but as some may have noticed, I am running into an issue at almost every step of the way.

I kept getting out-of-memory (OOM) errors when sampling, but after setting

%env XLA_PYTHON_CLIENT_PREALLOCATE=false

and using:

idata_logit = pm.sampling_jax.sample_numpyro_nuts(draws=1000, chains=4, chain_method="parallel", progress_bar=True, postprocessing_backend="cpu")

the GPU model ran in 13 min, compared to 45 min on a CPU. When I set postprocessing_backend to "gpu" I kept getting errors. I traced them back to JAX's backend_compile() function, which was throwing the error and contains lots of #TODO comments, so I suspect something odd is going on there.