Thanks for your very fast feedback, @ricardoV94 & @jessegrabowski !
> Knowing nothing about the actual problem you want to solve, I can't give more specific advice.
As said in my first post, my goal is to model several correlated discrete time series. These are measurements related to some underlying states, which I want to determine and possibly also extrapolate into the future. I don't yet know which exact model I should implement for this problem. However, it would be easiest for me to model these time series as a non-linear state space, since that fits best with my intuition about the problem and would therefore let me try out several different dynamics, moving successively from very simple ones to more complex ones.
The pendulum model, however, is entirely made up, just so this thread has a self-contained example showing some typical structures. But I would agree that it should (almost) fit into the `PyMCStateSpace` framework.
> If your model is non-linear, you can always linearize it around the steady state (assuming it's stationary; if it's not, you can transform it to be stationary). If it's non-Gaussian, you can often introduce latent states to e.g. introduce auto-correlation in the innovations.
I actually expect the dynamics of the model to show some switching behavior (e.g. saturation or bang-bang dynamics), which cannot be approximated well enough by linearization. My rationale for also considering non-Gaussian noise is that the domain of some variables is bounded, e.g. to the non-negative reals. I'm also not yet sure whether introducing a few discrete states would make sense, but for now I think it is reasonable to treat this as a continuous problem.
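To make that intent concrete, here is a toy sketch of the kind of switching pieces I have in mind; the helpers and the parameters `v_max`, `threshold`, and `u_max` are made up purely for illustration and are not part of the pendulum example below:

```python
import pytensor.tensor as pt

def saturate(v, v_max):
    # Hard saturation: the state cannot leave the interval [-v_max, v_max]
    return pt.clip(v, -v_max, v_max)

def bang_bang(x, threshold, u_max):
    # Bang-bang actuation: full force in one direction, switching at a threshold
    return pt.switch(x > threshold, -u_max, u_max)
```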
Given this context, I think @ricardoV94's suggestion would be the best way forward: register only the random states as a `CustomDist` and recover the deterministic states afterwards. It is maybe not the most convenient structure for trying out many different dynamics, but likely the best I can get for now. I've adapted the pendulum model accordingly:
```python
def step(v, x, d, k, sigma, dt):
    # Stochastic velocity update: quadratic drag plus process noise, plus the restoring force
    v_next = -pm.Normal.dist(mu=d * dt * v * pt.abs(v), sigma=sigma) - k * dt * pt.sin(x) + v
    x_next = v * dt + x
    return [v_next, x_next], pm.pytensorf.collect_default_updates(
        outputs=[v_next, x_next], inputs=[v, x, d, k, sigma, dt]
    )


def step_d(v, x, d, k, sigma, dt):
    # Deterministic position update, driven by an already-sampled velocity sequence
    x_next = v * dt + x
    return [x_next], pm.pytensorf.collect_default_updates(
        outputs=[x_next], inputs=[v, x, d, k, sigma, dt]
    )
```
```python
def statespace(v0, x0, d, k, sigma, dt, n_steps, *args, **kwargs):
    # Rolls out the joint dynamics, but only returns the random velocity states
    [v, x], updates = pytensor.scan(
        fn=step,
        sequences=[],
        outputs_info=[v0, x0],
        non_sequences=[d, k, sigma, dt],
        n_steps=n_steps,
        name='statespace',
        strict=True,
    )
    return v


def statespace_d(v, x0, d, k, sigma, dt, n_steps, *args, **kwargs):
    # Recovers the deterministic positions from the sampled velocities
    x, updates = pytensor.scan(
        fn=step_d,
        sequences=[v],
        outputs_info=[x0],
        non_sequences=[d, k, sigma, dt],
        n_steps=n_steps,
        name='statespace',
        strict=True,
    )
    return x
```
```python
with pm.Model() as model:
    time = pd.RangeIndex(0, 200)
    dt = time.step
    n_steps = pm.Data('n_steps', time.size).astype('int')  # For whatever reason, I get an AssertionError during posterior sampling if this is not cast into pm.Data
    model.add_coord('idx_time', time)
    model.add_coord('states', ['v', 'x'])

    k = pm.Gamma('k', mu=0.02, sigma=0.01)
    d = pm.Gamma('d', mu=0.4, sigma=0.2)
    sigma = pm.Gamma('sigma', mu=0.02, sigma=0.01)
    x0 = pm.Normal('x0', mu=0, sigma=0.5)
    v0 = pm.Normal('v0', mu=0, sigma=0.5)
    eps = 0.01

    # Only the random velocity states are registered as a CustomDist ...
    v = pm.CustomDist('v', v0, x0, d, k, sigma, dt, n_steps, dist=statespace, dims=('idx_time',))
    # ... and the deterministic positions are recovered from them afterwards
    x = pm.Deterministic('x', statespace_d(v, x0, d, k, sigma, dt, n_steps), dims=('idx_time',))
    state = pm.Deterministic('state', pt.stack([v, x]).T, dims=('idx_time', 'states'))  # Just to keep the same variable name for postprocessing
    x_obs = pm.Normal('x_obs', mu=x, sigma=eps, dims=('idx_time',))

with pm.do(model, {k: 0.02, d: 0.4, sigma: 0.02, x0: 2.4, v0: 0}) as model_generative:
    idata_generative = pm.sample_prior_predictive()
data_x = idata_generative.prior.x_obs.sel(chain=0, draw=42).values

with pm.observe(model, {x_obs: data_x}) as model_observed:
    idata = pm.sample_prior_predictive()
```
However, when I sample from this model with `numpyro`, I get the following error after all samples have been drawn:
```
---------------------------------------------------------------------------
IndexError Traceback (most recent call last)
Cell In[30], line 2
1 with model_observed:
----> 2 idata.extend(pm.sample(nuts_sampler='numpyro'))
3 # idata.extend(pm.sample(nuts_sampler='blackjax'))
4 # idata.extend(pm.sample(nuts_sampler='nutpie'))
5 # idata.extend(pm.sample(nuts_sampler='pymc'))
6 idata
File /opt/conda/lib/python3.12/site-packages/pymc/sampling/mcmc.py:725, in sample(draws, tune, chains, cores, random_seed, progressbar, progressbar_theme, step, var_names, nuts_sampler, initvals, init, jitter_max_retries, n_init, trace, discard_tuned_samples, compute_convergence_checks, keep_warning_stat, return_inferencedata, idata_kwargs, nuts_sampler_kwargs, callback, mp_ctx, blas_cores, model, **kwargs)
720 raise ValueError(
721 "Model can not be sampled with NUTS alone. Your model is probably not continuous."
722 )
724 with joined_blas_limiter():
--> 725 return _sample_external_nuts(
726 sampler=nuts_sampler,
727 draws=draws,
728 tune=tune,
729 chains=chains,
730 target_accept=kwargs.pop("nuts", {}).get("target_accept", 0.8),
731 random_seed=random_seed,
732 initvals=initvals,
733 model=model,
734 var_names=var_names,
735 progressbar=progressbar,
736 idata_kwargs=idata_kwargs,
737 compute_convergence_checks=compute_convergence_checks,
738 nuts_sampler_kwargs=nuts_sampler_kwargs,
739 **kwargs,
740 )
742 if isinstance(step, list):
743 step = CompoundStep(step)
File /opt/conda/lib/python3.12/site-packages/pymc/sampling/mcmc.py:356, in _sample_external_nuts(sampler, draws, tune, chains, target_accept, random_seed, initvals, model, var_names, progressbar, idata_kwargs, compute_convergence_checks, nuts_sampler_kwargs, **kwargs)
353 elif sampler in ("numpyro", "blackjax"):
354 import pymc.sampling.jax as pymc_jax
--> 356 idata = pymc_jax.sample_jax_nuts(
357 draws=draws,
358 tune=tune,
359 chains=chains,
360 target_accept=target_accept,
361 random_seed=random_seed,
362 initvals=initvals,
363 model=model,
364 var_names=var_names,
365 progressbar=progressbar,
366 nuts_sampler=sampler,
367 idata_kwargs=idata_kwargs,
368 compute_convergence_checks=compute_convergence_checks,
369 **nuts_sampler_kwargs,
370 )
371 return idata
373 else:
File /opt/conda/lib/python3.12/site-packages/pymc/sampling/jax.py:655, in sample_jax_nuts(draws, tune, chains, target_accept, random_seed, initvals, jitter, model, var_names, nuts_kwargs, progressbar, keep_untransformed, chain_method, postprocessing_backend, postprocessing_vectorize, postprocessing_chunks, idata_kwargs, compute_convergence_checks, nuts_sampler)
652 log_likelihood = None
654 jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=vars_to_sample)
--> 655 result = _postprocess_samples(
656 jax_fn,
657 raw_mcmc_samples,
658 postprocessing_backend=postprocessing_backend,
659 postprocessing_vectorize=postprocessing_vectorize,
660 donate_samples=True,
661 )
662 del raw_mcmc_samples
663 mcmc_samples = {v.name: r for v, r in zip(vars_to_sample, result)}
File /opt/conda/lib/python3.12/site-packages/pymc/sampling/jax.py:205, in _postprocess_samples(jax_fn, raw_mcmc_samples, postprocessing_backend, postprocessing_vectorize, donate_samples)
202 def process_fn(x):
203 return jax.vmap(jax.vmap(jax_fn))(*_device_put(x, postprocessing_backend))
--> 205 return jax.jit(process_fn, donate_argnums=0 if donate_samples else None)(raw_mcmc_samples)
207 else:
208 raise ValueError(f"Unrecognized postprocessing_vectorize: {postprocessing_vectorize}")
[... skipping hidden 11 frame]
File /opt/conda/lib/python3.12/site-packages/pymc/sampling/jax.py:203, in _postprocess_samples.<locals>.process_fn(x)
202 def process_fn(x):
--> 203 return jax.vmap(jax.vmap(jax_fn))(*_device_put(x, postprocessing_backend))
[... skipping hidden 6 frame]
File /tmp/tmppi3fv53j:33, in jax_funcified_fgraph(k_log_, d_log_, sigma_log_, x0, v0, v)
31 scalar_variable = scalar_from_tensor(tensor_variable_10)
32 # Subtensor{start:stop:step}(v, 0, ScalarFromTensor.0, 1)
---> 33 tensor_variable_11 = subtensor(v, scalar_constant_1, scalar_variable, scalar_constant_2)
34 # Scan{statespace, while_loop=False, inplace=none}(200, Subtensor{start:stop:step}.0, SetSubtensor{:stop}.0)
35 tensor_variable_12 = scan(tensor_constant_1, tensor_variable_11, tensor_variable_4)
File /opt/conda/lib/python3.12/site-packages/pytensor/link/jax/dispatch/subtensor.py:45, in jax_funcify_Subtensor.<locals>.subtensor(x, *ilists)
42 if len(indices) == 1:
43 indices = indices[0]
---> 45 return x.__getitem__(indices)
File /opt/conda/lib/python3.12/site-packages/jax/_src/numpy/array_methods.py:739, in _forward_operator_to_aval.<locals>.op(self, *args)
738 def op(self, *args):
--> 739 return getattr(self.aval, f"_{name}")(self, *args)
File /opt/conda/lib/python3.12/site-packages/jax/_src/numpy/array_methods.py:352, in _getitem(self, item)
351 def _getitem(self, item):
--> 352 return lax_numpy._rewriting_take(self, item)
File /opt/conda/lib/python3.12/site-packages/jax/_src/numpy/lax_numpy.py:6594, in _rewriting_take(arr, idx, indices_are_sorted, unique_indices, mode, fill_value)
6591 return lax.dynamic_index_in_dim(arr, idx, keepdims=False)
6593 treedef, static_idx, dynamic_idx = _split_index_for_jit(idx, arr.shape)
-> 6594 return _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted,
6595 unique_indices, mode, fill_value)
File /opt/conda/lib/python3.12/site-packages/jax/_src/numpy/lax_numpy.py:6603, in _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted, unique_indices, mode, fill_value)
6600 def _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted,
6601 unique_indices, mode, fill_value):
6602 idx = _merge_static_and_dynamic_indices(treedef, static_idx, dynamic_idx)
-> 6603 indexer = _index_to_gather(shape(arr), idx) # shared with _scatter_update
6604 y = arr
6606 if fill_value is not None:
File /opt/conda/lib/python3.12/site-packages/jax/_src/numpy/lax_numpy.py:6846, in _index_to_gather(x_shape, idx, normalize_indices)
6837 if not all(_is_slice_element_none_or_constant_or_symbolic(elt)
6838 for elt in (i.start, i.stop, i.step)):
6839 msg = ("Array slice indices must have static start/stop/step to be used "
6840 "with NumPy indexing syntax. "
6841 f"Found slice({i.start}, {i.stop}, {i.step}). "
(...)
6844 "dynamic_update_slice (JAX does not support dynamically sized "
6845 "arrays within JIT compiled functions).")
-> 6846 raise IndexError(msg)
6848 start, step, slice_size = _preprocess_slice(i, x_shape[x_axis])
6849 slice_shape.append(slice_size)
IndexError: Array slice indices must have static start/stop/step to be used with NumPy indexing syntax. Found slice(0, Traced<ShapedArray(int64[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, 1). To index a statically sized array at a dynamic position, try lax.dynamic_slice/dynamic_update_slice (JAX does not support dynamically sized arrays within JIT compiled functions).
```
With `blackjax` and `nutpie` I get similar error messages before sampling has even started, while the `pymc` NUTS sampler is unfortunately too slow to give me results in reasonable time. Here is the `blackjax` traceback:
```
---------------------------------------------------------------------------
IndexError Traceback (most recent call last)
Cell In[34], line 3
1 with model_observed:
2 # idata.extend(pm.sample(nuts_sampler='numpyro'))
----> 3 idata.extend(pm.sample(nuts_sampler='blackjax'))
4 # idata.extend(pm.sample(nuts_sampler='nutpie'))
5 # idata.extend(pm.sample(nuts_sampler='pymc'))
6 idata
File /opt/conda/lib/python3.12/site-packages/pymc/sampling/mcmc.py:725, in sample(draws, tune, chains, cores, random_seed, progressbar, progressbar_theme, step, var_names, nuts_sampler, initvals, init, jitter_max_retries, n_init, trace, discard_tuned_samples, compute_convergence_checks, keep_warning_stat, return_inferencedata, idata_kwargs, nuts_sampler_kwargs, callback, mp_ctx, blas_cores, model, **kwargs)
720 raise ValueError(
721 "Model can not be sampled with NUTS alone. Your model is probably not continuous."
722 )
724 with joined_blas_limiter():
--> 725 return _sample_external_nuts(
726 sampler=nuts_sampler,
727 draws=draws,
728 tune=tune,
729 chains=chains,
730 target_accept=kwargs.pop("nuts", {}).get("target_accept", 0.8),
731 random_seed=random_seed,
732 initvals=initvals,
733 model=model,
734 var_names=var_names,
735 progressbar=progressbar,
736 idata_kwargs=idata_kwargs,
737 compute_convergence_checks=compute_convergence_checks,
738 nuts_sampler_kwargs=nuts_sampler_kwargs,
739 **kwargs,
740 )
742 if isinstance(step, list):
743 step = CompoundStep(step)
File /opt/conda/lib/python3.12/site-packages/pymc/sampling/mcmc.py:356, in _sample_external_nuts(sampler, draws, tune, chains, target_accept, random_seed, initvals, model, var_names, progressbar, idata_kwargs, compute_convergence_checks, nuts_sampler_kwargs, **kwargs)
353 elif sampler in ("numpyro", "blackjax"):
354 import pymc.sampling.jax as pymc_jax
--> 356 idata = pymc_jax.sample_jax_nuts(
357 draws=draws,
358 tune=tune,
359 chains=chains,
360 target_accept=target_accept,
361 random_seed=random_seed,
362 initvals=initvals,
363 model=model,
364 var_names=var_names,
365 progressbar=progressbar,
366 nuts_sampler=sampler,
367 idata_kwargs=idata_kwargs,
368 compute_convergence_checks=compute_convergence_checks,
369 **nuts_sampler_kwargs,
370 )
371 return idata
373 else:
File /opt/conda/lib/python3.12/site-packages/pymc/sampling/jax.py:655, in sample_jax_nuts(draws, tune, chains, target_accept, random_seed, initvals, jitter, model, var_names, nuts_kwargs, progressbar, keep_untransformed, chain_method, postprocessing_backend, postprocessing_vectorize, postprocessing_chunks, idata_kwargs, compute_convergence_checks, nuts_sampler)
652 log_likelihood = None
654 jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=vars_to_sample)
--> 655 result = _postprocess_samples(
656 jax_fn,
657 raw_mcmc_samples,
658 postprocessing_backend=postprocessing_backend,
659 postprocessing_vectorize=postprocessing_vectorize,
660 donate_samples=True,
661 )
662 del raw_mcmc_samples
663 mcmc_samples = {v.name: r for v, r in zip(vars_to_sample, result)}
File /opt/conda/lib/python3.12/site-packages/pymc/sampling/jax.py:205, in _postprocess_samples(jax_fn, raw_mcmc_samples, postprocessing_backend, postprocessing_vectorize, donate_samples)
202 def process_fn(x):
203 return jax.vmap(jax.vmap(jax_fn))(*_device_put(x, postprocessing_backend))
--> 205 return jax.jit(process_fn, donate_argnums=0 if donate_samples else None)(raw_mcmc_samples)
207 else:
208 raise ValueError(f"Unrecognized postprocessing_vectorize: {postprocessing_vectorize}")
[... skipping hidden 11 frame]
File /opt/conda/lib/python3.12/site-packages/pymc/sampling/jax.py:203, in _postprocess_samples.<locals>.process_fn(x)
202 def process_fn(x):
--> 203 return jax.vmap(jax.vmap(jax_fn))(*_device_put(x, postprocessing_backend))
[... skipping hidden 6 frame]
File /tmp/tmpx0adh6st:33, in jax_funcified_fgraph(k_log_, d_log_, sigma_log_, x0, v0, v)
31 scalar_variable = scalar_from_tensor(tensor_variable_10)
32 # Subtensor{start:stop:step}(v, 0, ScalarFromTensor.0, 1)
---> 33 tensor_variable_11 = subtensor(v, scalar_constant_1, scalar_variable, scalar_constant_2)
34 # Scan{statespace, while_loop=False, inplace=none}(200, Subtensor{start:stop:step}.0, SetSubtensor{:stop}.0)
35 tensor_variable_12 = scan(tensor_constant_1, tensor_variable_11, tensor_variable_4)
File /opt/conda/lib/python3.12/site-packages/pytensor/link/jax/dispatch/subtensor.py:45, in jax_funcify_Subtensor.<locals>.subtensor(x, *ilists)
42 if len(indices) == 1:
43 indices = indices[0]
---> 45 return x.__getitem__(indices)
File /opt/conda/lib/python3.12/site-packages/jax/_src/numpy/array_methods.py:739, in _forward_operator_to_aval.<locals>.op(self, *args)
738 def op(self, *args):
--> 739 return getattr(self.aval, f"_{name}")(self, *args)
File /opt/conda/lib/python3.12/site-packages/jax/_src/numpy/array_methods.py:352, in _getitem(self, item)
351 def _getitem(self, item):
--> 352 return lax_numpy._rewriting_take(self, item)
File /opt/conda/lib/python3.12/site-packages/jax/_src/numpy/lax_numpy.py:6594, in _rewriting_take(arr, idx, indices_are_sorted, unique_indices, mode, fill_value)
6591 return lax.dynamic_index_in_dim(arr, idx, keepdims=False)
6593 treedef, static_idx, dynamic_idx = _split_index_for_jit(idx, arr.shape)
-> 6594 return _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted,
6595 unique_indices, mode, fill_value)
File /opt/conda/lib/python3.12/site-packages/jax/_src/numpy/lax_numpy.py:6603, in _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted, unique_indices, mode, fill_value)
6600 def _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted,
6601 unique_indices, mode, fill_value):
6602 idx = _merge_static_and_dynamic_indices(treedef, static_idx, dynamic_idx)
-> 6603 indexer = _index_to_gather(shape(arr), idx) # shared with _scatter_update
6604 y = arr
6606 if fill_value is not None:
File /opt/conda/lib/python3.12/site-packages/jax/_src/numpy/lax_numpy.py:6846, in _index_to_gather(x_shape, idx, normalize_indices)
6837 if not all(_is_slice_element_none_or_constant_or_symbolic(elt)
6838 for elt in (i.start, i.stop, i.step)):
6839 msg = ("Array slice indices must have static start/stop/step to be used "
6840 "with NumPy indexing syntax. "
6841 f"Found slice({i.start}, {i.stop}, {i.step}). "
(...)
6844 "dynamic_update_slice (JAX does not support dynamically sized "
6845 "arrays within JIT compiled functions).")
-> 6846 raise IndexError(msg)
6848 start, step, slice_size = _preprocess_slice(i, x_shape[x_axis])
6849 slice_shape.append(slice_size)
IndexError: Array slice indices must have static start/stop/step to be used with NumPy indexing syntax. Found slice(0, Traced<ShapedArray(int64[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, 1). To index a statically sized array at a dynamic position, try lax.dynamic_slice/dynamic_update_slice (JAX does not support dynamically sized arrays within JIT compiled functions).
```
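If it helps to narrow this down, I can reproduce what I believe is the same failure mode in pure JAX (my assumption being that the dynamic `n_steps` ends up as a traced slice bound inside the jitted postprocessing function, exactly like the `slice(0, Traced<...>, 1)` above). Slicing with a traced stop value fails under `jax.jit`, while `lax.dynamic_slice` with a statically known slice size works:

```python
import jax
import jax.numpy as jnp
from jax import lax

x = jnp.arange(10.0)

@jax.jit
def sliced(n):
    # Fails under jit: the slice stop is a traced value, not a static int
    return x[0:n]

@jax.jit
def dynamic(n):
    # Works under jit: dynamic start index, statically known slice size
    return lax.dynamic_slice(x, (n,), (3,))

print(dynamic(2))  # [2. 3. 4.]
# sliced(5)        # raises the same IndexError as above
```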
Any idea how this could be fixed? Right now, I assume it is a bug in the samplers. I have the following versions installed:
```
pymc      5.16.2
pytensor  2.25.2
arviz     0.19.0
numpyro   0.15.1
blackjax  1.2.2
nutpie    0.13.2
jax       0.4.30
numpy     1.26.4
```
By the way, is there any method to programmatically check whether a pytensor variable is random or deterministic, something like `assert x.is_deterministic()` and `assert v.is_random()`?
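So far, the closest I've come is inspecting the owner `Op` myself, roughly as sketched below. The helper names are my own invention, and I'm not sure this handles Scan-based graphs like the one above correctly:

```python
from pytensor.graph.basic import ancestors
from pytensor.tensor.random.op import RandomVariable

def is_random(var):
    # True if the variable is directly the output of a RandomVariable Op
    return var.owner is not None and isinstance(var.owner.op, RandomVariable)

def depends_on_randomness(var):
    # True if any ancestor in the graph is the output of a RandomVariable Op
    return any(
        a.owner is not None and isinstance(a.owner.op, RandomVariable)
        for a in ancestors([var])
    )
```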
> You're running into all the snags that caused me to work on the statespace module!
Yeah, while looking through several other threads and notebooks, I had already gathered your motivation for this…