State Space Model with Random & Deterministic Dynamics

Thanks for your very fast feedback, @ricardoV94 & @jessegrabowski!

Knowing nothing about the actual problem you want to solve I can’t give more specific advice.

As mentioned in my first post, my goal is to model several correlated discrete time series. These are measurements related to some underlying states, which I want to determine and possibly also extrapolate into the future. I don't yet know exactly which model I should implement for this problem. However, the easiest approach for me would be to model these time series as a non-linear state space model, since that fits best with my personal intuition about the problem and would let me try out several different dynamics, progressing from very simple to more complex ones.

The pendulum model, however, is entirely made up, just to have a self-contained example for this thread that shows some typical structures. But I agree that it should (almost) fit into the PyMCStateSpace framework. :wink:

If your model is non-linear, you can always linearize it around the steady state (assuming it's stationary; if it's not, you can transform it to be stationary). If it's non-Gaussian, you can often introduce latent states to, e.g., introduce auto-correlation in the innovations.

I actually expect the dynamics of the model to show some switching behavior (e.g. saturation or bang-bang dynamics), which cannot be adequately approximated by linearization. My rationale for also considering non-Gaussian noise is that the domain of some variables is bounded, e.g. restricted to the non-negative reals. I'm also not yet sure whether introducing a few discrete states would make sense, but for now I think it's reasonable to treat it as a continuous problem.
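For the pendulum example below, the linearization route can be written out explicitly. This is a worked sketch around the assumed rest point (v, x) = (0, 0); reading the `step` update as equations, with the zero-mean noise written as η_t (its sign does not matter because it is symmetric):

```latex
\begin{aligned}
v_{t+1} &= v_t - d\,\Delta t\, v_t \lvert v_t \rvert - k\,\Delta t \sin x_t + \eta_t,
  \qquad \eta_t \sim \mathcal{N}(0, \sigma^2) \\
x_{t+1} &= x_t + \Delta t\, v_t \\[6pt]
\begin{pmatrix} v_{t+1} \\ x_{t+1} \end{pmatrix}
&\approx
\begin{pmatrix} 1 & -k\,\Delta t \\ \Delta t & 1 \end{pmatrix}
\begin{pmatrix} v_t \\ x_t \end{pmatrix}
+
\begin{pmatrix} \eta_t \\ 0 \end{pmatrix}
\end{aligned}
```

Because the slope of v|v| is 2|v|, which vanishes at v = 0, the quadratic damping term drops out of the linear model entirely, and sin x ≈ x is already poor at the generative starting point x0 = 2.4 (sin 2.4 ≈ 0.68). This is the kind of structure that linearization loses.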

Given this context, I think @ricardoV94's suggestion is the best way forward: register only the random states as a CustomDist and recover the deterministic states afterwards. It may not be the most convenient structure for trying out many different dynamics, but it's likely the best I can get for now. :slight_smile: I've adapted the pendulum model accordingly:

import pandas as pd
import pymc as pm
import pytensor
import pytensor.tensor as pt

def step(v, x, d, k, sigma, dt):
    v_next = -pm.Normal.dist(mu=d * dt * v * pt.abs(v), sigma=sigma) - k * dt * pt.sin(x) + v
    x_next = v * dt + x
    return [v_next, x_next], pm.pytensorf.collect_default_updates(outputs=[v_next, x_next], inputs=[v, x, d, k, sigma, dt])

def step_d(v, x, d, k, sigma, dt):
    x_next = v * dt + x
    return [x_next], pm.pytensorf.collect_default_updates(outputs=[x_next], inputs=[v, x, d, k, sigma, dt])

def statespace(v0, x0, d, k, sigma, dt, n_steps, *args, **kwargs):
    [v, x], updates = pytensor.scan(
        fn=step,
        sequences=[],
        outputs_info=[v0, x0],
        non_sequences=[d, k, sigma, dt],
        n_steps=n_steps,
        name='statespace',
        strict=True,
    )
    return v

def statespace_d(v, x0, d, k, sigma, dt, n_steps, *args, **kwargs):
    x, updates = pytensor.scan(
        fn=step_d,
        sequences=[v],
        outputs_info=[x0],
        non_sequences=[d, k, sigma, dt],
        n_steps=n_steps,
        name='statespace',
        strict=True,
    )
    return x

with pm.Model() as model:
    time = pd.RangeIndex(0, 200)
    dt = time.step
    n_steps = pm.Data('n_steps', time.size).astype('int')  # For whatever reason, I'm getting an AssertionError during posterior sampling if this is not wrapped in pm.Data

    model.add_coord('idx_time', time)
    model.add_coord('states', ['v', 'x'])

    k = pm.Gamma('k', mu=0.02, sigma=0.01)
    d = pm.Gamma('d', mu=0.4, sigma=0.2)
    sigma = pm.Gamma('sigma', mu=0.02, sigma=0.01)
    x0 = pm.Normal('x0', mu=0, sigma=0.5)
    v0 = pm.Normal('v0', mu=0, sigma=0.5)
    eps = 0.01

    v = pm.CustomDist('v', v0, x0, d, k, sigma, dt, n_steps, dist=statespace, dims=('idx_time',))
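    v_prev = pt.concatenate([pt.shape_padleft(v0), v[:-1]])  # step() advances x with the *previous* velocity
    x = pm.Deterministic('x', statespace_d(v_prev, x0, d, k, sigma, dt, n_steps), dims=('idx_time',))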
    state = pm.Deterministic('state', pt.stack([v,x]).T, dims=('idx_time', 'states'))  # Just to keep same variable name for postprocessing

    x_obs = pm.Normal('x_obs', mu=x, sigma=eps, dims=('idx_time',))

with pm.do(model, {k: 0.02, d: 0.4, sigma: 0.02, x0: 2.4, v0: 0}) as model_generative:
    idata_generative = pm.sample_prior_predictive()
    data_x = idata_generative.prior.x_obs.sel(chain=0, draw=42).values

with pm.observe(model, {x_obs: data_x}) as model_observed:
    idata = pm.sample_prior_predictive()
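If the goal is for `x` to reproduce the positions that `step` computes internally, the index alignment deserves a check: scanning `step_d` over `v` as a sequence means step t sees v[t], the velocity *after* step t, while `step` advances x with the previous velocity. The NumPy-only sketch below mirrors the recursion and compares both recoveries against it. `simulate`, `recover_like_step_d` and `recover_shifted` are hypothetical helpers, and sigma=0 is assumed so the run is deterministic.

```python
import numpy as np


def simulate(v0, x0, d, k, dt, n_steps, sigma=0.0, rng=None):
    """NumPy mirror of `step`: both updates use the *previous* (v, x)."""
    rng = np.random.default_rng() if rng is None else rng
    v, x = v0, x0
    vs, xs = [], []
    for _ in range(n_steps):
        # -Normal(mu, sigma) equals -mu minus zero-mean noise, as in `step`
        noise = sigma * rng.standard_normal()
        v_next = -(d * dt * v * abs(v) + noise) - k * dt * np.sin(x) + v
        x_next = v * dt + x  # explicit Euler: uses the velocity from the previous step
        v, x = v_next, x_next
        vs.append(v)
        xs.append(x)
    return np.array(vs), np.array(xs)


def recover_like_step_d(v, x0, dt):
    """What scanning `step_d` over `v` computes: x[t] = x[t-1] + dt * v[t]."""
    return x0 + dt * np.cumsum(v)


def recover_shifted(v, v0, x0, dt):
    """x[t] = x[t-1] + dt * v[t-1] with v[-1] = v0, i.e. the update used in `step`."""
    v_prev = np.concatenate([[v0], v[:-1]])
    return x0 + dt * np.cumsum(v_prev)


# Deterministic run (sigma=0) with the other parameter values from the pm.do call
v, x = simulate(v0=0.0, x0=2.4, d=0.4, k=0.02, dt=1, n_steps=200)
print(np.abs(recover_like_step_d(v, 2.4, 1) - x).max())  # non-zero: one step ahead
print(np.abs(recover_shifted(v, 0.0, 2.4, 1) - x).max())  # zero up to rounding
```

With v0 = 0 the mismatch at every step is simply dt * v[t], so the unshifted recovery runs one step ahead of the state that enters sin(x) inside `step`.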

However, when I sample from this model with numpyro, I get the following error after all samples have been drawn:

---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
Cell In[30], line 2
      1 with model_observed:
----> 2     idata.extend(pm.sample(nuts_sampler='numpyro'))
      3     # idata.extend(pm.sample(nuts_sampler='blackjax'))
      4     # idata.extend(pm.sample(nuts_sampler='nutpie'))
      5     # idata.extend(pm.sample(nuts_sampler='pymc'))
      6 idata

File /opt/conda/lib/python3.12/site-packages/pymc/sampling/mcmc.py:725, in sample(draws, tune, chains, cores, random_seed, progressbar, progressbar_theme, step, var_names, nuts_sampler, initvals, init, jitter_max_retries, n_init, trace, discard_tuned_samples, compute_convergence_checks, keep_warning_stat, return_inferencedata, idata_kwargs, nuts_sampler_kwargs, callback, mp_ctx, blas_cores, model, **kwargs)
    720         raise ValueError(
    721             "Model can not be sampled with NUTS alone. Your model is probably not continuous."
    722         )
    724     with joined_blas_limiter():
--> 725         return _sample_external_nuts(
    726             sampler=nuts_sampler,
    727             draws=draws,
    728             tune=tune,
    729             chains=chains,
    730             target_accept=kwargs.pop("nuts", {}).get("target_accept", 0.8),
    731             random_seed=random_seed,
    732             initvals=initvals,
    733             model=model,
    734             var_names=var_names,
    735             progressbar=progressbar,
    736             idata_kwargs=idata_kwargs,
    737             compute_convergence_checks=compute_convergence_checks,
    738             nuts_sampler_kwargs=nuts_sampler_kwargs,
    739             **kwargs,
    740         )
    742 if isinstance(step, list):
    743     step = CompoundStep(step)

File /opt/conda/lib/python3.12/site-packages/pymc/sampling/mcmc.py:356, in _sample_external_nuts(sampler, draws, tune, chains, target_accept, random_seed, initvals, model, var_names, progressbar, idata_kwargs, compute_convergence_checks, nuts_sampler_kwargs, **kwargs)
    353 elif sampler in ("numpyro", "blackjax"):
    354     import pymc.sampling.jax as pymc_jax
--> 356     idata = pymc_jax.sample_jax_nuts(
    357         draws=draws,
    358         tune=tune,
    359         chains=chains,
    360         target_accept=target_accept,
    361         random_seed=random_seed,
    362         initvals=initvals,
    363         model=model,
    364         var_names=var_names,
    365         progressbar=progressbar,
    366         nuts_sampler=sampler,
    367         idata_kwargs=idata_kwargs,
    368         compute_convergence_checks=compute_convergence_checks,
    369         **nuts_sampler_kwargs,
    370     )
    371     return idata
    373 else:

File /opt/conda/lib/python3.12/site-packages/pymc/sampling/jax.py:655, in sample_jax_nuts(draws, tune, chains, target_accept, random_seed, initvals, jitter, model, var_names, nuts_kwargs, progressbar, keep_untransformed, chain_method, postprocessing_backend, postprocessing_vectorize, postprocessing_chunks, idata_kwargs, compute_convergence_checks, nuts_sampler)
    652     log_likelihood = None
    654 jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=vars_to_sample)
--> 655 result = _postprocess_samples(
    656     jax_fn,
    657     raw_mcmc_samples,
    658     postprocessing_backend=postprocessing_backend,
    659     postprocessing_vectorize=postprocessing_vectorize,
    660     donate_samples=True,
    661 )
    662 del raw_mcmc_samples
    663 mcmc_samples = {v.name: r for v, r in zip(vars_to_sample, result)}

File /opt/conda/lib/python3.12/site-packages/pymc/sampling/jax.py:205, in _postprocess_samples(jax_fn, raw_mcmc_samples, postprocessing_backend, postprocessing_vectorize, donate_samples)
    202     def process_fn(x):
    203         return jax.vmap(jax.vmap(jax_fn))(*_device_put(x, postprocessing_backend))
--> 205     return jax.jit(process_fn, donate_argnums=0 if donate_samples else None)(raw_mcmc_samples)
    207 else:
    208     raise ValueError(f"Unrecognized postprocessing_vectorize: {postprocessing_vectorize}")

    [... skipping hidden 11 frame]

File /opt/conda/lib/python3.12/site-packages/pymc/sampling/jax.py:203, in _postprocess_samples.<locals>.process_fn(x)
    202 def process_fn(x):
--> 203     return jax.vmap(jax.vmap(jax_fn))(*_device_put(x, postprocessing_backend))

    [... skipping hidden 6 frame]

File /tmp/tmppi3fv53j:33, in jax_funcified_fgraph(k_log_, d_log_, sigma_log_, x0, v0, v)
     31 scalar_variable = scalar_from_tensor(tensor_variable_10)
     32 # Subtensor{start:stop:step}(v, 0, ScalarFromTensor.0, 1)
---> 33 tensor_variable_11 = subtensor(v, scalar_constant_1, scalar_variable, scalar_constant_2)
     34 # Scan{statespace, while_loop=False, inplace=none}(200, Subtensor{start:stop:step}.0, SetSubtensor{:stop}.0)
     35 tensor_variable_12 = scan(tensor_constant_1, tensor_variable_11, tensor_variable_4)

File /opt/conda/lib/python3.12/site-packages/pytensor/link/jax/dispatch/subtensor.py:45, in jax_funcify_Subtensor.<locals>.subtensor(x, *ilists)
     42 if len(indices) == 1:
     43     indices = indices[0]
---> 45 return x.__getitem__(indices)

File /opt/conda/lib/python3.12/site-packages/jax/_src/numpy/array_methods.py:739, in _forward_operator_to_aval.<locals>.op(self, *args)
    738 def op(self, *args):
--> 739   return getattr(self.aval, f"_{name}")(self, *args)

File /opt/conda/lib/python3.12/site-packages/jax/_src/numpy/array_methods.py:352, in _getitem(self, item)
    351 def _getitem(self, item):
--> 352   return lax_numpy._rewriting_take(self, item)

File /opt/conda/lib/python3.12/site-packages/jax/_src/numpy/lax_numpy.py:6594, in _rewriting_take(arr, idx, indices_are_sorted, unique_indices, mode, fill_value)
   6591       return lax.dynamic_index_in_dim(arr, idx, keepdims=False)
   6593 treedef, static_idx, dynamic_idx = _split_index_for_jit(idx, arr.shape)
-> 6594 return _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted,
   6595                unique_indices, mode, fill_value)

File /opt/conda/lib/python3.12/site-packages/jax/_src/numpy/lax_numpy.py:6603, in _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted, unique_indices, mode, fill_value)
   6600 def _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted,
   6601             unique_indices, mode, fill_value):
   6602   idx = _merge_static_and_dynamic_indices(treedef, static_idx, dynamic_idx)
-> 6603   indexer = _index_to_gather(shape(arr), idx)  # shared with _scatter_update
   6604   y = arr
   6606   if fill_value is not None:

File /opt/conda/lib/python3.12/site-packages/jax/_src/numpy/lax_numpy.py:6846, in _index_to_gather(x_shape, idx, normalize_indices)
   6837 if not all(_is_slice_element_none_or_constant_or_symbolic(elt)
   6838            for elt in (i.start, i.stop, i.step)):
   6839   msg = ("Array slice indices must have static start/stop/step to be used "
   6840          "with NumPy indexing syntax. "
   6841          f"Found slice({i.start}, {i.stop}, {i.step}). "
   (...)
   6844          "dynamic_update_slice (JAX does not support dynamically sized "
   6845          "arrays within JIT compiled functions).")
-> 6846   raise IndexError(msg)
   6848 start, step, slice_size = _preprocess_slice(i, x_shape[x_axis])
   6849 slice_shape.append(slice_size)

IndexError: Array slice indices must have static start/stop/step to be used with NumPy indexing syntax. Found slice(0, Traced<ShapedArray(int64[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, 1). To index a statically sized array at a dynamic position, try lax.dynamic_slice/dynamic_update_slice (JAX does not support dynamically sized arrays within JIT compiled functions).

With blackjax and nutpie I get similar error messages before sampling has started. PyMC's own NUTS sampler is unfortunately too slow to give me results in a reasonable time. Here is the blackjax traceback:

---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
Cell In[34], line 3
      1 with model_observed:
      2     # idata.extend(pm.sample(nuts_sampler='numpyro'))
----> 3     idata.extend(pm.sample(nuts_sampler='blackjax'))
      4     # idata.extend(pm.sample(nuts_sampler='nutpie'))
      5     # idata.extend(pm.sample(nuts_sampler='pymc'))
      6 idata

File /opt/conda/lib/python3.12/site-packages/pymc/sampling/mcmc.py:725, in sample(draws, tune, chains, cores, random_seed, progressbar, progressbar_theme, step, var_names, nuts_sampler, initvals, init, jitter_max_retries, n_init, trace, discard_tuned_samples, compute_convergence_checks, keep_warning_stat, return_inferencedata, idata_kwargs, nuts_sampler_kwargs, callback, mp_ctx, blas_cores, model, **kwargs)
    720         raise ValueError(
    721             "Model can not be sampled with NUTS alone. Your model is probably not continuous."
    722         )
    724     with joined_blas_limiter():
--> 725         return _sample_external_nuts(
    726             sampler=nuts_sampler,
    727             draws=draws,
    728             tune=tune,
    729             chains=chains,
    730             target_accept=kwargs.pop("nuts", {}).get("target_accept", 0.8),
    731             random_seed=random_seed,
    732             initvals=initvals,
    733             model=model,
    734             var_names=var_names,
    735             progressbar=progressbar,
    736             idata_kwargs=idata_kwargs,
    737             compute_convergence_checks=compute_convergence_checks,
    738             nuts_sampler_kwargs=nuts_sampler_kwargs,
    739             **kwargs,
    740         )
    742 if isinstance(step, list):
    743     step = CompoundStep(step)

File /opt/conda/lib/python3.12/site-packages/pymc/sampling/mcmc.py:356, in _sample_external_nuts(sampler, draws, tune, chains, target_accept, random_seed, initvals, model, var_names, progressbar, idata_kwargs, compute_convergence_checks, nuts_sampler_kwargs, **kwargs)
    353 elif sampler in ("numpyro", "blackjax"):
    354     import pymc.sampling.jax as pymc_jax
--> 356     idata = pymc_jax.sample_jax_nuts(
    357         draws=draws,
    358         tune=tune,
    359         chains=chains,
    360         target_accept=target_accept,
    361         random_seed=random_seed,
    362         initvals=initvals,
    363         model=model,
    364         var_names=var_names,
    365         progressbar=progressbar,
    366         nuts_sampler=sampler,
    367         idata_kwargs=idata_kwargs,
    368         compute_convergence_checks=compute_convergence_checks,
    369         **nuts_sampler_kwargs,
    370     )
    371     return idata
    373 else:

File /opt/conda/lib/python3.12/site-packages/pymc/sampling/jax.py:655, in sample_jax_nuts(draws, tune, chains, target_accept, random_seed, initvals, jitter, model, var_names, nuts_kwargs, progressbar, keep_untransformed, chain_method, postprocessing_backend, postprocessing_vectorize, postprocessing_chunks, idata_kwargs, compute_convergence_checks, nuts_sampler)
    652     log_likelihood = None
    654 jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=vars_to_sample)
--> 655 result = _postprocess_samples(
    656     jax_fn,
    657     raw_mcmc_samples,
    658     postprocessing_backend=postprocessing_backend,
    659     postprocessing_vectorize=postprocessing_vectorize,
    660     donate_samples=True,
    661 )
    662 del raw_mcmc_samples
    663 mcmc_samples = {v.name: r for v, r in zip(vars_to_sample, result)}

File /opt/conda/lib/python3.12/site-packages/pymc/sampling/jax.py:205, in _postprocess_samples(jax_fn, raw_mcmc_samples, postprocessing_backend, postprocessing_vectorize, donate_samples)
    202     def process_fn(x):
    203         return jax.vmap(jax.vmap(jax_fn))(*_device_put(x, postprocessing_backend))
--> 205     return jax.jit(process_fn, donate_argnums=0 if donate_samples else None)(raw_mcmc_samples)
    207 else:
    208     raise ValueError(f"Unrecognized postprocessing_vectorize: {postprocessing_vectorize}")

    [... skipping hidden 11 frame]

File /opt/conda/lib/python3.12/site-packages/pymc/sampling/jax.py:203, in _postprocess_samples.<locals>.process_fn(x)
    202 def process_fn(x):
--> 203     return jax.vmap(jax.vmap(jax_fn))(*_device_put(x, postprocessing_backend))

    [... skipping hidden 6 frame]

File /tmp/tmpx0adh6st:33, in jax_funcified_fgraph(k_log_, d_log_, sigma_log_, x0, v0, v)
     31 scalar_variable = scalar_from_tensor(tensor_variable_10)
     32 # Subtensor{start:stop:step}(v, 0, ScalarFromTensor.0, 1)
---> 33 tensor_variable_11 = subtensor(v, scalar_constant_1, scalar_variable, scalar_constant_2)
     34 # Scan{statespace, while_loop=False, inplace=none}(200, Subtensor{start:stop:step}.0, SetSubtensor{:stop}.0)
     35 tensor_variable_12 = scan(tensor_constant_1, tensor_variable_11, tensor_variable_4)

File /opt/conda/lib/python3.12/site-packages/pytensor/link/jax/dispatch/subtensor.py:45, in jax_funcify_Subtensor.<locals>.subtensor(x, *ilists)
     42 if len(indices) == 1:
     43     indices = indices[0]
---> 45 return x.__getitem__(indices)

File /opt/conda/lib/python3.12/site-packages/jax/_src/numpy/array_methods.py:739, in _forward_operator_to_aval.<locals>.op(self, *args)
    738 def op(self, *args):
--> 739   return getattr(self.aval, f"_{name}")(self, *args)

File /opt/conda/lib/python3.12/site-packages/jax/_src/numpy/array_methods.py:352, in _getitem(self, item)
    351 def _getitem(self, item):
--> 352   return lax_numpy._rewriting_take(self, item)

File /opt/conda/lib/python3.12/site-packages/jax/_src/numpy/lax_numpy.py:6594, in _rewriting_take(arr, idx, indices_are_sorted, unique_indices, mode, fill_value)
   6591       return lax.dynamic_index_in_dim(arr, idx, keepdims=False)
   6593 treedef, static_idx, dynamic_idx = _split_index_for_jit(idx, arr.shape)
-> 6594 return _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted,
   6595                unique_indices, mode, fill_value)

File /opt/conda/lib/python3.12/site-packages/jax/_src/numpy/lax_numpy.py:6603, in _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted, unique_indices, mode, fill_value)
   6600 def _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted,
   6601             unique_indices, mode, fill_value):
   6602   idx = _merge_static_and_dynamic_indices(treedef, static_idx, dynamic_idx)
-> 6603   indexer = _index_to_gather(shape(arr), idx)  # shared with _scatter_update
   6604   y = arr
   6606   if fill_value is not None:

File /opt/conda/lib/python3.12/site-packages/jax/_src/numpy/lax_numpy.py:6846, in _index_to_gather(x_shape, idx, normalize_indices)
   6837 if not all(_is_slice_element_none_or_constant_or_symbolic(elt)
   6838            for elt in (i.start, i.stop, i.step)):
   6839   msg = ("Array slice indices must have static start/stop/step to be used "
   6840          "with NumPy indexing syntax. "
   6841          f"Found slice({i.start}, {i.stop}, {i.step}). "
   (...)
   6844          "dynamic_update_slice (JAX does not support dynamically sized "
   6845          "arrays within JIT compiled functions).")
-> 6846   raise IndexError(msg)
   6848 start, step, slice_size = _preprocess_slice(i, x_shape[x_axis])
   6849 slice_shape.append(slice_size)

IndexError: Array slice indices must have static start/stop/step to be used with NumPy indexing syntax. Found slice(0, Traced<ShapedArray(int64[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, 1). To index a statically sized array at a dynamic position, try lax.dynamic_slice/dynamic_update_slice (JAX does not support dynamically sized arrays within JIT compiled functions).

Any idea how this could be fixed? Right now, I assume it is a bug in the samplers. I have the following versions installed:

pymc                          5.16.2
pytensor                      2.25.2
arviz                         0.19.0
numpyro                       0.15.1
blackjax                      1.2.2
nutpie                        0.13.2
jax                           0.4.30
numpy                         1.26.4

By the way, is there a way to programmatically check whether a pytensor variable is random or deterministic, e.g. something like assert x.is_deterministic() and assert v.is_random()?

You’re running into all the snags that caused me to work on the statespace module!

Yes, from looking through several other threads and notebooks, I had already picked up on your motivation for this. :wink: