Hello,
I’ve been trying to fit data consisting of 143K rows of data on 4 GPUs and keep running into memory issues.
I added the `idata_kwargs=dict(log_likelihood=False)` argument with no luck.
Here is the error message.
BufferAssignment stats:
parameter allocation: 5.12MiB
constant allocation: 0B
maybe_live_out allocation: 42.75GiB
preallocated temp allocation: 0B
total allocation: 42.76GiB
total fragmentation: 0B (0.00%)
Peak buffers:
Buffer 1:
Size: 42.75GiB
Operator: op_name="jit(gather)/jit(main)/gather[dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 1, 3), collapsed_slice_dims=(2,), start_index_map=(2,)) slice_sizes=(4, 1000, 1, 10) unique_indices=False indices_are_sorted=False mode=GatherScatterMode.PROMISE_IN_BOUNDS fill_value=None]" source_file="/opt/conda/lib/python3.7/site-packages/aesara/link/jax/dispatch.py" source_line=597
XLA Label: gather
Shape: f64[4,1000,143464,10]
==========================
Buffer 2:
Size: 4.58MiB
Entry Parameter Subshape: f64[4,1000,15,10]
==========================
Buffer 3:
Size: 560.4KiB
Entry Parameter Subshape: s32[143464,1]
==========================
---------------------------------------------------------------------------
XlaRuntimeError Traceback (most recent call last)
/tmp/ipykernel_45777/372256884.py in <module>
----> 1 trace = pymc.sampling_jax.sample_numpyro_nuts(model=model, tune=2000, draws = 1000, idata_kwargs=dict(log_likelihood = False))
2 # trace = pm.sample(model=model, tune=2000, draws = 1000, idata_kwargs=dict(log_likelihood = False))
/opt/conda/lib/python3.7/site-packages/pymc/sampling_jax.py in sample_numpyro_nuts(draws, tune, chains, target_accept, random_seed, initvals, model, var_names, progress_bar, keep_untransformed, chain_method, postprocessing_backend, idata_kwargs, nuts_kwargs)
538 jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=vars_to_sample)
539 result = jax.vmap(jax.vmap(jax_fn))(
--> 540 *jax.device_put(raw_mcmc_samples, jax.devices(postprocessing_backend)[0])
541 )
542 mcmc_samples = {v.name: r for v, r in zip(vars_to_sample, result)}
[... skipping hidden 6 frame]
/opt/conda/lib/python3.7/site-packages/aesara/link/utils.py in jax_funcified_fgraph(mu_slope, sigma_loc_slope_log_, sigma_item_slope_log_, offset_loc_slope, offset_item_slope, mu_intercept, sigma_loc_intercept_log_, sigma_item_intercept_log_, offset_loc_intercept, offset_item_intercept, mu_delta, sigma_loc_delta_log_, sigma_item_delta_log_, offset_loc_delta, offset_item_delta, yearly_mu, yearly_sigma_log_, yearly_beta, beta_month)
16 yearly_sigma = exp(yearly_sigma_log_)
17 # AdvancedSubtensor1(yearly_beta, TensorConstant{[ 0 0 0 .. 14 14 14]})
---> 18 auto_64007 = subtensor(yearly_beta, auto_1489)
19 # Elemwise{Mul}[(0, 1)](TensorConstant{[[1. 1. 1...0. 0. 0.]]}, AdvancedSubtensor1.0)
20 auto_64213 = elemwise(auto_63448, auto_64007)
/opt/conda/lib/python3.7/site-packages/aesara/link/jax/dispatch.py in subtensor(x, *ilists)
595 indices = indices[0]
596
--> 597 return x.__getitem__(indices)
598
599 return subtensor
[... skipping hidden 1 frame]
/opt/conda/lib/python3.7/site-packages/jax/_src/numpy/lax_numpy.py in _rewriting_take(arr, idx, indices_are_sorted, unique_indices, mode, fill_value)
3596 treedef, static_idx, dynamic_idx = _split_index_for_jit(idx, arr.shape)
3597 return _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted,
-> 3598 unique_indices, mode, fill_value)
3599
3600 # TODO(phawkins): re-enable jit after fixing excessive recompilation for
/opt/conda/lib/python3.7/site-packages/jax/_src/numpy/lax_numpy.py in _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted, unique_indices, mode, fill_value)
3626 unique_indices=unique_indices or indexer.unique_indices,
3627 indices_are_sorted=indices_are_sorted or indexer.indices_are_sorted,
-> 3628 mode=mode, fill_value=fill_value)
3629
3630 # Reverses axes with negative strides.
[... skipping hidden 16 frame]
/opt/conda/lib/python3.7/site-packages/jax/_src/dispatch.py in _execute_compiled(name, compiled, input_handler, output_buffer_counts, result_handler, has_unordered_effects, ordered_effects, kept_var_idx, *args)
731 in_flat, token_handler = _add_tokens(has_unordered_effects, ordered_effects,
732 device, in_flat)
--> 733 out_flat = compiled.execute(in_flat)
734 check_special(name, out_flat)
735 out_bufs = unflatten(out_flat, output_buffer_counts)
XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 45908480000 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
parameter allocation: 5.12MiB
constant allocation: 0B
maybe_live_out allocation: 42.75GiB
preallocated temp allocation: 0B
total allocation: 42.76GiB
total fragmentation: 0B (0.00%)
Peak buffers:
Buffer 1:
Size: 42.75GiB
Operator: op_name="jit(gather)/jit(main)/gather[dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 1, 3), collapsed_slice_dims=(2,), start_index_map=(2,)) slice_sizes=(4, 1000, 1, 10) unique_indices=False indices_are_sorted=False mode=GatherScatterMode.PROMISE_IN_BOUNDS fill_value=None]" source_file="/opt/conda/lib/python3.7/site-packages/aesara/link/jax/dispatch.py" source_line=597
XLA Label: gather
Shape: f64[4,1000,143464,10]
==========================
Buffer 2:
Size: 4.58MiB
Entry Parameter Subshape: f64[4,1000,15,10]
==========================
Buffer 3:
Size: 560.4KiB
Entry Parameter Subshape: s32[143464,1]
==========================
I also tried adding `aesara.config.floatX = "float32"`
to reduce the size, and got the weird error below.
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
/tmp/ipykernel_4317/372256884.py in <module>
----> 1 trace = pymc.sampling_jax.sample_numpyro_nuts(model=model, tune=2000, draws = 1000, idata_kwargs=dict(log_likelihood = False))
2 # trace = pm.sample(model=model, tune=2000, draws = 1000, idata_kwargs=dict(log_likelihood = False))
/opt/conda/lib/python3.7/site-packages/pymc/sampling_jax.py in sample_numpyro_nuts(draws, tune, chains, target_accept, random_seed, initvals, model, var_names, progress_bar, keep_untransformed, chain_method, postprocessing_backend, idata_kwargs, nuts_kwargs)
521 init_params=init_params,
522 extra_fields=(
--> 523 "num_steps",
524 "potential_energy",
525 "energy",
/opt/conda/lib/python3.7/site-packages/numpyro/infer/mcmc.py in run(self, rng_key, extra_fields, init_params, *args, **kwargs)
597 states, last_state = _laxmap(partial_map_fn, map_args)
598 elif self.chain_method == "parallel":
--> 599 states, last_state = pmap(partial_map_fn)(map_args)
600 else:
601 assert self.chain_method == "vectorized"
[... skipping hidden 17 frame]
/opt/conda/lib/python3.7/site-packages/numpyro/infer/mcmc.py in _single_chain_mcmc(self, init, args, kwargs, collect_fields)
414 progbar_desc=partial(_get_progbar_desc_str, lower_idx, phase),
415 diagnostics_fn=diagnostics,
--> 416 num_chains=self.num_chains if self.chain_method == "parallel" else 1,
417 )
418 states, last_val = collect_vals
/opt/conda/lib/python3.7/site-packages/numpyro/util.py in fori_collect(lower, upper, body_fun, init_val, transform, progbar, return_last_val, collection_size, thinning, **progbar_opts)
337 if not progbar:
338 last_val, collection, _, _ = fori_loop(
--> 339 0, upper, _body_fn, (init_val, collection, start_idx, thinning)
340 )
341 elif num_chains > 1:
/opt/conda/lib/python3.7/site-packages/numpyro/util.py in fori_loop(lower, upper, body_fun, init_val)
139 return val
140 else:
--> 141 return lax.fori_loop(lower, upper, body_fun, init_val)
142
143
[... skipping hidden 14 frame]
/opt/conda/lib/python3.7/site-packages/numpyro/util.py in _body_fn(i, vals)
321 def _body_fn(i, vals):
322 val, collection, start_idx, thinning = vals
--> 323 val = body_fun(val)
324 idx = (i - start_idx) // thinning
325 collection = cond(
/opt/conda/lib/python3.7/site-packages/numpyro/infer/mcmc.py in _sample_fn_nojit_args(state, sampler, args, kwargs)
170 def _sample_fn_nojit_args(state, sampler, args, kwargs):
171 # state is a tuple of size 1 - containing HMCState
--> 172 return (sampler.sample(state[0], args, kwargs),)
173
174
/opt/conda/lib/python3.7/site-packages/numpyro/infer/hmc.py in sample(self, state, model_args, model_kwargs)
769 :return: Next `state` after running HMC.
770 """
--> 771 return self._sample_fn(state, model_args, model_kwargs)
772
773 def __getstate__(self):
/opt/conda/lib/python3.7/site-packages/numpyro/infer/hmc.py in sample_kernel(hmc_state, model_args, model_kwargs)
472 model_kwargs,
473 rng_key_transition,
--> 474 *hmc_length_args,
475 )
476 # not update adapt_state after warmup phase
/opt/conda/lib/python3.7/site-packages/numpyro/infer/hmc.py in _nuts_next(step_size, inverse_mass_matrix, vv_state, model_args, model_kwargs, rng_key, max_treedepth_current)
413 rng_key,
414 max_delta_energy=max_delta_energy,
--> 415 max_tree_depth=(max_treedepth_current, max(max_treedepth)),
416 )
417 accept_prob = binary_tree.sum_accept_probs / binary_tree.num_proposals
/opt/conda/lib/python3.7/site-packages/numpyro/infer/hmc_util.py in build_tree(verlet_update, kinetic_fn, verlet_state, inverse_mass_matrix, step_size, rng_key, max_delta_energy, max_tree_depth)
1175
1176 state = (tree, rng_key)
-> 1177 tree, _ = while_loop(_cond_fn, _body_fn, state)
1178 return tree
1179
/opt/conda/lib/python3.7/site-packages/numpyro/util.py in while_loop(cond_fun, body_fun, init_val)
129 return val
130 else:
--> 131 return lax.while_loop(cond_fun, body_fun, init_val)
132
133
[... skipping hidden 11 frame]
/opt/conda/lib/python3.7/site-packages/numpyro/infer/hmc_util.py in _body_fn(state)
1170 max_delta_energy,
1171 r_ckpts,
-> 1172 r_sum_ckpts,
1173 )
1174 return tree, key
/opt/conda/lib/python3.7/site-packages/numpyro/infer/hmc_util.py in _double_tree(current_tree, vv_update, kinetic_fn, inverse_mass_matrix, step_size, going_right, rng_key, energy_current, max_delta_energy, r_ckpts, r_sum_ckpts)
925 max_delta_energy,
926 r_ckpts,
--> 927 r_sum_ckpts,
928 )
929
/opt/conda/lib/python3.7/site-packages/numpyro/infer/hmc_util.py in _iterative_build_subtree(prototype_tree, vv_update, kinetic_fn, inverse_mass_matrix, step_size, going_right, rng_key, energy_current, max_delta_energy, r_ckpts, r_sum_ckpts)
1060
1061 tree, turning, _, _, _ = while_loop(
-> 1062 _cond_fn, _body_fn, (basetree, False, r_ckpts, r_sum_ckpts, rng_key)
1063 )
1064 # update depth and turning condition
/opt/conda/lib/python3.7/site-packages/numpyro/util.py in while_loop(cond_fun, body_fun, init_val)
129 return val
130 else:
--> 131 return lax.while_loop(cond_fun, body_fun, init_val)
132
133
[... skipping hidden 11 frame]
/opt/conda/lib/python3.7/site-packages/numpyro/infer/hmc_util.py in _body_fn(state)
1027 transition_rng_key,
1028 ),
-> 1029 lambda x: _combine_tree(*x, False),
1030 )
1031
/opt/conda/lib/python3.7/site-packages/numpyro/util.py in cond(pred, true_operand, true_fun, false_operand, false_fun)
119 return false_fun(false_operand)
120 else:
--> 121 return lax.cond(pred, true_operand, true_fun, false_operand, false_fun)
122
123
[... skipping hidden 16 frame]
/opt/conda/lib/python3.7/site-packages/numpyro/infer/hmc_util.py in <lambda>(x)
1027 transition_rng_key,
1028 ),
-> 1029 lambda x: _combine_tree(*x, False),
1030 )
1031
/opt/conda/lib/python3.7/site-packages/numpyro/infer/hmc_util.py in _combine_tree(current_tree, new_tree, inverse_mass_matrix, going_right, rng_key, biased_transition)
776 ),
777 (new_tree, current_tree),
--> 778 lambda trees: (
779 trees[0].z_left,
780 trees[0].r_left,
/opt/conda/lib/python3.7/site-packages/numpyro/util.py in cond(pred, true_operand, true_fun, false_operand, false_fun)
119 return false_fun(false_operand)
120 else:
--> 121 return lax.cond(pred, true_operand, true_fun, false_operand, false_fun)
122
123
[... skipping hidden 4 frame]
/opt/conda/lib/python3.7/site-packages/jax/_src/lax/control_flow/common.py in _check_tree_and_avals(what, tree1, avals1, tree2, avals2)
103 diff = tree_map(_show_diff, tree_unflatten(tree1, avals1),
104 tree_unflatten(tree2, avals2))
--> 105 raise TypeError(f"{what} must have identical types, got\n{diff}.")
106
107
TypeError: true_fun and false_fun output must have identical types, got
(['DIFFERENT ShapedArray(float32[]) vs. ShapedArray(float64[])', 'DIFFERENT ShapedArray(float32[]) vs. ShapedArray(float64[])', 'DIFFERENT ShapedArray(float32[]) vs. ShapedArray(float64[])', 'DIFFERENT ShapedArray(float32[15]) vs. ShapedArray(float64[15])', 'DIFFERENT ShapedArray(float32[217]) vs. ShapedArray(float64[217])', 'DIFFERENT ShapedArray(float32[]) vs. ShapedArray(float64[])', 'DIFFERENT ShapedArray(float32[]) vs. ShapedArray(float64[])', 'DIFFERENT ShapedArray(float32[]) vs. ShapedArray(float64[])', 'DIFFERENT ShapedArray(float32[15]) vs. ShapedArray(float64[15])', 'DIFFERENT ShapedArray(float32[217]) vs. ShapedArray(float64[217])', 'DIFFERENT ShapedArray(float32[]) vs. ShapedArray(float64[])', 'DIFFERENT ShapedArray(float32[]) vs. ShapedArray(float64[])', 'DIFFERENT ShapedArray(float32[]) vs. ShapedArray(float64[])', 'DIFFERENT ShapedArray(float32[15,8]) vs. ShapedArray(float64[15,8])', 'DIFFERENT ShapedArray(float32[217,8]) vs. ShapedArray(float64[217,8])', 'DIFFERENT ShapedArray(float32[]) vs. ShapedArray(float64[])', 'DIFFERENT ShapedArray(float32[]) vs. ShapedArray(float64[])', 'DIFFERENT ShapedArray(float32[15,10]) vs. ShapedArray(float64[15,10])', 'DIFFERENT ShapedArray(float32[12]) vs. ShapedArray(float64[12])'], ['ShapedArray(float64[])', 'ShapedArray(float64[])', 'ShapedArray(float64[])', 'ShapedArray(float64[15])', 'ShapedArray(float64[217])', 'ShapedArray(float64[])', 'ShapedArray(float64[])', 'ShapedArray(float64[])', 'ShapedArray(float64[15])', 'ShapedArray(float64[217])', 'ShapedArray(float64[])', 'ShapedArray(float64[])', 'ShapedArray(float64[])', 'ShapedArray(float64[15,8])', 'ShapedArray(float64[217,8])', 'ShapedArray(float64[])', 'ShapedArray(float64[])', 'ShapedArray(float64[15,10])', 'ShapedArray(float64[12])'], ['DIFFERENT ShapedArray(float32[]) vs. ShapedArray(float64[])', 'DIFFERENT ShapedArray(float32[]) vs. ShapedArray(float64[])', 'DIFFERENT ShapedArray(float32[]) vs. 
ShapedArray(float64[])', 'DIFFERENT ShapedArray(float32[15]) vs. ShapedArray(float64[15])', 'DIFFERENT ShapedArray(float32[217]) vs. ShapedArray(float64[217])', 'DIFFERENT ShapedArray(float32[]) vs. ShapedArray(float64[])', 'DIFFERENT ShapedArray(float32[]) vs. ShapedArray(float64[])', 'DIFFERENT ShapedArray(float32[]) vs. ShapedArray(float64[])', 'DIFFERENT ShapedArray(float32[15]) vs. ShapedArray(float64[15])', 'DIFFERENT ShapedArray(float32[217]) vs. ShapedArray(float64[217])', 'DIFFERENT ShapedArray(float32[]) vs. ShapedArray(float64[])', 'DIFFERENT ShapedArray(float32[]) vs. ShapedArray(float64[])', 'DIFFERENT ShapedArray(float32[]) vs. ShapedArray(float64[])', 'DIFFERENT ShapedArray(float32[15,8]) vs. ShapedArray(float64[15,8])', 'DIFFERENT ShapedArray(float32[217,8]) vs. ShapedArray(float64[217,8])', 'DIFFERENT ShapedArray(float32[]) vs. ShapedArray(float64[])', 'DIFFERENT ShapedArray(float32[]) vs. ShapedArray(float64[])', 'DIFFERENT ShapedArray(float32[15,10]) vs. ShapedArray(float64[15,10])', 'DIFFERENT ShapedArray(float32[12]) vs. ShapedArray(float64[12])'], ['DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])', 'DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])', 'DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])', 'DIFFERENT ShapedArray(float64[15]) vs. ShapedArray(float32[15])', 'DIFFERENT ShapedArray(float64[217]) vs. ShapedArray(float32[217])', 'DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])', 'DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])', 'DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])', 'DIFFERENT ShapedArray(float64[15]) vs. ShapedArray(float32[15])', 'DIFFERENT ShapedArray(float64[217]) vs. ShapedArray(float32[217])', 'DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])', 'DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])', 'DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])', 'DIFFERENT ShapedArray(float64[15,8]) vs. 
ShapedArray(float32[15,8])', 'DIFFERENT ShapedArray(float64[217,8]) vs. ShapedArray(float32[217,8])', 'DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])', 'DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])', 'DIFFERENT ShapedArray(float64[15,10]) vs. ShapedArray(float32[15,10])', 'DIFFERENT ShapedArray(float64[12]) vs. ShapedArray(float32[12])'], ['ShapedArray(float64[])', 'ShapedArray(float64[])', 'ShapedArray(float64[])', 'ShapedArray(float64[15])', 'ShapedArray(float64[217])', 'ShapedArray(float64[])', 'ShapedArray(float64[])', 'ShapedArray(float64[])', 'ShapedArray(float64[15])', 'ShapedArray(float64[217])', 'ShapedArray(float64[])', 'ShapedArray(float64[])', 'ShapedArray(float64[])', 'ShapedArray(float64[15,8])', 'ShapedArray(float64[217,8])', 'ShapedArray(float64[])', 'ShapedArray(float64[])', 'ShapedArray(float64[15,10])', 'ShapedArray(float64[12])'], ['DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])', 'DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])', 'DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])', 'DIFFERENT ShapedArray(float64[15]) vs. ShapedArray(float32[15])', 'DIFFERENT ShapedArray(float64[217]) vs. ShapedArray(float32[217])', 'DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])', 'DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])', 'DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])', 'DIFFERENT ShapedArray(float64[15]) vs. ShapedArray(float32[15])', 'DIFFERENT ShapedArray(float64[217]) vs. ShapedArray(float32[217])', 'DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])', 'DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])', 'DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])', 'DIFFERENT ShapedArray(float64[15,8]) vs. ShapedArray(float32[15,8])', 'DIFFERENT ShapedArray(float64[217,8]) vs. ShapedArray(float32[217,8])', 'DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])', 'DIFFERENT ShapedArray(float64[]) vs. 
ShapedArray(float32[])', 'DIFFERENT ShapedArray(float64[15,10]) vs. ShapedArray(float32[15,10])', 'DIFFERENT ShapedArray(float64[12]) vs. ShapedArray(float32[12])']).
The model will run on CPU but takes 12 hours to fit. The GPUs sample faster, then error out. Is this truly a memory issue, or is there a bug somewhere with jax specifically?
Update: This will run with 5,400-ish rows but not 17K-ish rows. This still seems like too little. Maybe the model is too big? See below for context.
"items":items,
'months':months,
'changepoints':df_train.index.get_level_values(0)[np.argwhere(np.diff(A, axis=0) != 0)[:, 0]],
"yearly_components": [f'yearly_{f}_{i+1}' for f in ['cos', 'sin'] for i in range(yearly_fourier.shape[1] // 2)],
"obs_id":[f'{loc}_{time.year}_month_{time.month}_item_{item}' for time, loc, item in df_train.index.values]}
# Hierarchical piecewise-trend + seasonality model (Prophet-style) with
# partial pooling over locations and items. NOTE: the forum paste stripped
# the indentation; it is restored here.
with pm.Model(coords=coords) as model:
    # Mutable data containers so the fitted model can later be given new data.
    A_ = pm.Data('A', A, mutable=True, dims=['time', 'changepoints'])
    s_ = pm.Data('s', s, mutable=True, dims=['changepoints'])
    t_ = pm.Data('t', t, mutable=True, dims=['time'])
    yearly = pm.Data('yearly_season', yearly_fourier, mutable=True, dims=['obs_id', 'yearly_components'])

    # Slope: non-centered parameterization (offset * sigma) for the
    # location- and item-level deviations around a global mean.
    mu_slope = pm.Normal('mu_slope', mu=0, sigma=0.1)
    sigma_loc_slope = pm.HalfNormal('sigma_loc_slope', sigma=0.1)
    sigma_item_slope = pm.HalfNormal('sigma_item_slope', sigma=0.1)
    offset_loc_slope = pm.Normal('offset_loc_slope', mu=0, sigma=0.1, dims=['locations'])
    offset_item_slope = pm.Normal('offset_item_slope', mu=0, sigma=0.1, dims=['items'])
    loc_slope = sigma_loc_slope * offset_loc_slope
    item_slope = sigma_item_slope * offset_item_slope
    # Broadcast the group effects to one value per observation.
    initial_slope = mu_slope + loc_slope[location_idxs] + item_slope[item_idxs]

    # Intercept: same non-centered structure as the slope.
    mu_intercept = pm.Normal('mu_intercept', mu=0, sigma=0.1)
    sigma_loc_intercept = pm.HalfNormal('sigma_loc_intercept', sigma=0.1)
    sigma_item_intercept = pm.HalfNormal('sigma_item_intercept', sigma=0.1)
    offset_loc_intercept = pm.Normal('offset_loc_intercept', mu=0, sigma=0.1, dims=['locations'])
    offset_item_intercept = pm.Normal('offset_item_intercept', mu=0, sigma=0.1, dims=['items'])
    loc_intercept = sigma_loc_intercept * offset_loc_intercept
    item_intercept = sigma_item_intercept * offset_item_intercept
    initial_intercept = mu_intercept + loc_intercept[location_idxs] + item_intercept[item_idxs]

    # Changepoint deltas: per-(location, changepoint) and per-(item, changepoint)
    # trend adjustments, again non-centered.
    mu_delta = pm.Normal('mu_delta', 0, 0.1)
    sigma_loc_delta = pm.HalfNormal('sigma_loc_delta', sigma=0.1)
    sigma_item_delta = pm.HalfNormal('sigma_item_delta', sigma=0.1)
    offset_loc_delta = pm.Normal('offset_loc_delta', mu=0, sigma=0.25, dims=['locations', 'changepoints'])
    offset_item_delta = pm.Normal('offset_item_delta', mu=0, sigma=0.25, dims=['items', 'changepoints'])
    loc_delta = sigma_loc_delta * offset_loc_delta
    item_delta = sigma_item_delta * offset_item_delta
    delta = mu_delta + loc_delta[location_idxs, :] + item_delta[item_idxs, :]

    # Yearly seasonality: Fourier features dotted with per-location coefficients.
    yearly_mu = pm.Normal('yearly_mu', 0, 0.1)
    yearly_sigma = pm.HalfNormal('yearly_sigma', sigma=0.1)
    yearly_beta = pm.Normal('yearly_beta', yearly_mu, yearly_sigma, dims=['locations', 'yearly_components'])
    # NOTE(review): wrapping a length-143K vector in a Deterministic means every
    # posterior draw of it is recomputed and stored after sampling; the pasted
    # OOM traceback shows exactly this gather (f64[4,1000,143464,10]) blowing up
    # on GPU. Consider dropping the Deterministic wrapper (use the bare
    # expression) if 'yearly_seasonality' is not needed in the trace.
    yearly_seasonality = pm.Deterministic('yearly_seasonality', (yearly[time_idxs] * yearly_beta[location_idxs, :]).sum(axis=1), dims=['obs_id'])

    # Monthly effects.
    beta_month = pm.Normal('beta_month', mu=0, sigma=0.1, dims=['months'])

    # Piecewise-linear trend: changepoint matrix A gates which deltas apply at
    # each time; -s * A contributes the intercept correction at each changepoint.
    intercept = initial_intercept + ((-s_ * A_)[time_idxs, :] * delta).sum(axis=1)
    slope = initial_slope + (A_[time_idxs, :] * delta).sum(axis=1)
    # Same storage caveat as above applies to this obs_id-length Deterministic.
    mu = pm.Deterministic('mu', intercept + slope * t_ + yearly_seasonality + beta_month[month_idxs], dims=['obs_id'])

    likelihood = pm.Poisson('predicted_eaches',
                            mu=pm.math.exp(mu),
                            observed=y,
                            dims=['obs_id'],)

    # FIX for the reported OOM: the 42.75GiB allocation happens *after*
    # sampling, when sample_numpyro_nuts vmaps the deterministics over all
    # draws on the GPU (see the jax.device_put(..., jax.devices(
    # postprocessing_backend)[0]) frame in the traceback). Forcing the
    # postprocessing onto the CPU keeps that buffer in host RAM.
    trace = pymc.sampling_jax.sample_numpyro_nuts(model=model, tune=2000, chain_method='vectorized', draws=1000,
                                                  postprocessing_backend="cpu",
                                                  idata_kwargs=dict(log_likelihood=False))