# State Space Model with Random & Deterministic Dynamics

Hey, I’m currently trying to implement a general State Space Model using PyTensor’s `scan` function.
I’ve been roughly following the approach shown in these notebooks and in this discussion.
In the end, my goal is to model several correlated discrete time series with non-linear dynamics and non-gaussian noise.
However, for now I started with a very simple AR(1) model plus explicit integration of the time steps - to get a general feel for these models and to extend the implementation later.

So, this is the model I currently have:

``````def step(state, rho, sigma, mu):
t, x = state
t_next = t + 1
# t_next = pm.Normal.dist(mu=(t + 1), sigma=0.01)  # This would be working, but it's not physically meaningful
x_next = pm.Normal.dist(mu=(mu + rho * x), sigma=sigma)
state_next = pt.stack([t_next, x_next])
return state_next, pm.pytensorf.collect_default_updates(outputs=[state_next], inputs=[state, rho, sigma, mu])

def statespace(t0, x0, rho, sigma, mu, n_steps, *args, **kwargs):
    """Generative graph for ``pm.CustomDist``: unroll ``step`` for ``n_steps``.

    The scan output has one row per step; its columns follow the order in
    which ``step`` stacks them, ``(t, x)``.
    """
    state0 = pt.stack([t0, x0])
    state, _ = pytensor.scan(
        fn=step,
        sequences=[],
        outputs_info=state0,
        non_sequences=[rho, sigma, mu],
        n_steps=n_steps,
        name='statespace',
        strict=True,
    )
    return state

with pm.Model() as model:
    time = pd.RangeIndex(1950, 2100) + 1

    time = pm.Data('time', time, dims=('idx_time',))
    data_t = pm.Data('data_t', df_train['Year'], dims=('idx_obs',))
    data_x = pm.Data('data_x', df_train['CP Rate'], dims=('idx_obs',))

    # Start one step before the first modelled time so that the first scan
    # step (t0 + 1) lands on time.min().
    t0 = time.min() - 1
    n_steps = time.size
    x0 = pm.Normal('x0', mu=0, sigma=100)
    mu = pm.Normal('mu', mu=0, sigma=100)
    rho = pm.Normal('rho', mu=0, sigma=1)
    sigma = pm.Gamma('sigma', mu=10, sigma=10)
    eps = 0.1

    state = pm.CustomDist('state', t0, x0, rho, sigma, mu, n_steps, dist=statespace, dims=('idx_time', 'states'))
    t = state[:, 0]
    x = state[:, 1]
    # The 0/1 matrix eq(data_t, time) selects, for each observation year,
    # the matching entry of x.
    x_obs = pm.Normal('x_obs', mu=pt.eq(data_t[:, None], time) @ x, sigma=eps, observed=data_x, dims=('idx_obs',))
``````

Prior predictive sampling works fine. However, during posterior inference with `pm.sample()` I get the following error message:

``````---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[88], line 2
1 with model:
----> 2     idata.extend(pm.sample(nuts_sampler='numpyro'))

File /opt/conda/lib/python3.12/site-packages/pymc/sampling/mcmc.py:716, in sample(draws, tune, chains, cores, random_seed, progressbar, progressbar_theme, step, var_names, nuts_sampler, initvals, init, jitter_max_retries, n_init, trace, discard_tuned_samples, compute_convergence_checks, keep_warning_stat, return_inferencedata, idata_kwargs, nuts_sampler_kwargs, callback, mp_ctx, blas_cores, model, **kwargs)
713         auto_nuts_init = False
715 initial_points = None
--> 716 step = assign_step_methods(model, step, methods=pm.STEP_METHODS, step_kwargs=kwargs)
718 if nuts_sampler != "pymc":
719     if not isinstance(step, NUTS):

File /opt/conda/lib/python3.12/site-packages/pymc/sampling/mcmc.py:215, in assign_step_methods(model, step, methods, step_kwargs)
213 methods_list: list[type[BlockedStep]] = list(methods or pm.STEP_METHODS)
214 selected_steps: dict[type[BlockedStep], list] = {}
--> 215 model_logp = model.logp()
217 for var in model.value_vars:
218     if var not in assigned_vars:
219         # determine if a gradient can be computed

File /opt/conda/lib/python3.12/site-packages/pymc/model/core.py:742, in Model.logp(self, vars, jacobian, sum)
740 rv_logps: list[TensorVariable] = []
741 if rvs:
--> 742     rv_logps = transformed_conditional_logp(
743         rvs=rvs,
744         rvs_to_values=self.rvs_to_values,
745         rvs_to_transforms=self.rvs_to_transforms,
746         jacobian=jacobian,
747     )
748     assert isinstance(rv_logps, list)
750 # Replace random variables by their value variables in potential terms

File /opt/conda/lib/python3.12/site-packages/pymc/logprob/basic.py:611, in transformed_conditional_logp(rvs, rvs_to_values, rvs_to_transforms, jacobian, **kwargs)
608     transform_rewrite = TransformValuesRewrite(values_to_transforms)  # type: ignore
610 kwargs.setdefault("warn_rvs", False)
--> 611 temp_logp_terms = conditional_logp(
612     rvs_to_values,
613     extra_rewrites=transform_rewrite,
614     use_jacobian=jacobian,
615     **kwargs,
616 )
618 # The function returns the logp for every single value term we provided to it.
619 # This includes the extra values we plugged in above, so we filter those we
620 # actually wanted in the same order they were given in.
621 logp_terms = {}

File /opt/conda/lib/python3.12/site-packages/pymc/logprob/basic.py:541, in conditional_logp(rv_values, warn_rvs, ir_rewriter, extra_rewrites, **kwargs)
538 q_values = remapped_vars[: len(q_values)]
539 q_rv_inputs = remapped_vars[len(q_values) :]
--> 541 q_logprob_vars = _logprob(
542     node.op,
543     q_values,
544     *q_rv_inputs,
545     **kwargs,
546 )
548 if not isinstance(q_logprob_vars, list | tuple):
549     q_logprob_vars = [q_logprob_vars]

File /opt/conda/lib/python3.12/functools.py:909, in singledispatch.<locals>.wrapper(*args, **kw)
905 if not args:
906     raise TypeError(f'{funcname} requires at least '
907                     '1 positional argument')
--> 909 return dispatch(args[0].__class__)(*args, **kw)

File /opt/conda/lib/python3.12/site-packages/pymc/logprob/scan.py:309, in logprob_ScanRV(op, values, name, *inputs, **kwargs)
306 scan_args = ScanArgs.from_node(new_node)
307 rv_outer_outs = get_random_outer_outputs(scan_args)
--> 309 var_indices, rv_vars, io_vars = zip(*rv_outer_outs)
310 value_map = {_rv: _val for _rv, _val in zip(rv_vars, values)}
312 def create_inner_out_logp(value_map: dict[TensorVariable, TensorVariable]) -> TensorVariable:

ValueError: not enough values to unpack (expected 3, got 0)
``````

I assume that this error is caused by the fact that the first state `t` is deterministic, but both states are fed into `CustomDist` together. E.g. if I add some random noise to `t` within the `step` function (see commented code above), everything works fine. (But it’s physically not meaningful - time is not random…)

I have run out of ideas for refactoring this model so that random and deterministic state dynamics can be modelled simultaneously. During some refactoring attempts I also ran into this issue. So if anyone could point me in the right direction, I would be more than happy.

P.S.: I know that this simple model could be implemented more easily with `pm.AR` or the statespace models. But as mentioned, this is just the starting point for more complex models with a similar structure.

You have to model your deterministics separately from the RVs. CustomDist can’t handle deterministic (non-measurable) outputs.

In your case it means you will have two scans in your model. The one inside CustomDist and the one that computes t (although you could replace t by `pt.arange` but I imagine you may have more complex deterministics in mind).

Note this does not incur performance penalty as during sampling separate functions are used for computing deterministics and logp anyway.

2 Likes

Hey @ricardoV94, thanks for your fast reply! Yes, you are right, separating both states and performing two scans would work out with this model.

However, for many state space models, the states depend on each other, which makes such a separation in general very difficult. My previous model is maybe a bit too simple to illustrate this, so I tried to quickly assemble another model with interdependent random and deterministic states:

``````def step(state, d, k, sigma, dt):
v, x = state
v_next = -pm.Normal.dist(mu=d * dt * v * pt.abs(v), sigma=sigma) - k * dt * pt.sin(x) + v
x_next = v * dt + x  # This is causing a ValueError, as described above
# x_next = v * dt + pm.Normal.dist(mu=x, sigma=0.001)  # Again, this would work - but it's not physical
state_next = pt.stack([v_next, x_next])
return state_next, pm.pytensorf.collect_default_updates(outputs=[state_next], inputs=[state, d, k, sigma, dt])

def statespace(v0, x0, d, k, sigma, dt, n_steps, *args, **kwargs):
    """Generative graph for ``pm.CustomDist``: unroll ``step`` for ``n_steps``.

    The scan output has one row per step with columns ``(v, x)``, the order
    in which ``step`` stacks them.
    """
    state0 = pt.stack([v0, x0])
    state, _ = pytensor.scan(
        fn=step,
        sequences=[],
        outputs_info=state0,
        non_sequences=[d, k, sigma, dt],
        n_steps=n_steps,
        name='statespace',
        strict=True,
    )
    return state

with pm.Model() as model:
    time = pd.RangeIndex(0, 200)
    dt = time.step
    n_steps = pm.Data('n_steps', time.size).astype('int')  # For whatever reason, I'm getting an AssertionError during posterior sampling if this is not casted into pm.Data

    k = pm.Gamma('k', mu=0.02, sigma=0.01)
    d = pm.Gamma('d', mu=0.4, sigma=0.2)
    sigma = pm.Gamma('sigma', mu=0.02, sigma=0.01)
    x0 = pm.Normal('x0', mu=0, sigma=0.5)
    v0 = pm.Normal('v0', mu=0, sigma=0.5)
    eps = 0.01

    # Stacked (v, x) state: only v is random, so logp derivation fails on x.
    state = pm.CustomDist('state', v0, x0, d, k, sigma, dt, n_steps, dist=statespace, dims=('idx_time', 'states'))
    v = state[:, 0]
    x = state[:, 1]

    x_obs = pm.Normal('x_obs', mu=x, sigma=eps, dims=('idx_time',))

# Generate fake data: fix the parameters with pm.do and draw from the prior
# predictive, then keep one draw of x_obs as the synthetic observations.
with pm.do(model, {k: 0.02, d: 0.4, sigma: 0.02, x0: 2.4, v0: 0}) as model_generative:
    idata_generative = pm.sample_prior_predictive()
data_x = idata_generative.prior.x_obs.sel(chain=0, draw=42).values

# Sampling: condition x_obs on the synthetic data, then run prior
# predictive, NUTS (via numpyro) and posterior predictive.
with pm.observe(model, {x_obs: data_x}) as model_observed:
    idata = pm.sample_prior_predictive()
    idata.extend(pm.sample(nuts_sampler='numpyro'))
    idata = pm.sample_posterior_predictive(idata, extend_inferencedata=True, predictions=False)
``````

This model more or less represents a pendulum damped by turbulent drag. According to kinematics, position `x` and velocity `v` have a deterministic relationship, but the incremental drag force is random. As you can see, `v_next` and `x_next` depend on the previous values of both states - so a simple separation into two scans would not be possible here…

I had also in mind to separate the states after applying scan - but so far without any success:

``````    state = statespace(v0, x0, d, k, sigma, dt, n_steps)  # plain scan output, not wrapped in CustomDist; columns are (v, x)
# Attempt: split the stacked state and register the columns separately,
# v as a random variable and x as a Deterministic. The accompanying text
# reports that this approach did not succeed.
v = state[:, 0]
x = state[:, 1]
model.register_rv(v, 'v', dims=('idx_time',))
x = pm.Deterministic('x', x, dims=('idx_time',))
...
``````

Your custom dist scan can use/compute deterministics but can’t return them. So you shouldn’t stack the two, but instead return only the non-deterministic sequence. You can always recover the deterministic sequence with a second scan outside the customdist that takes initial value and the non deterministic sequence as input (which is what PyMC would always need).

Unfortunately there’s no machinery to derive the deterministic function from the measurable one automatically so you have to figure it out.

For instance, say you deterministically carry around the two step difference when generating the random sequence, something like:

``````def dist(x0, ...):
def step(x_tm1, diff_tm1):
x_t = pm.Normal.dist((x_tm1 + diff_tm1) / 2)
diff_t = x_t - x_tm1
return (x_t, diff_t), update_t...

retun x_seq  # don't return diff_t

with pm.Model() as m:
x0 = ...
x_seq = pm.CustomDist(”x_seq”, x0, dist=dist, ...)
# recompute diff_t for access later
diff_t = pm.Deterministic("diff_t", pt.diff(x_seq))
``````

Your custom dist can return (and sample) x sequences. If you want access to the underlying deterministic diffs, you have to recompute them either with a deterministic scan or, in this case, a simpler deterministic operation.

The problem with your solution is that during inverse MCMC sampling you would need PyMC to figure out how to make `x` a deterministic function of the sampled `v`, which is not something that’s encoded in your generative scan, which builds `x` and `v` simultaneously. PyMC could in principle derive it (just like it derives the logp of `v` for inverse sampling), but there’s no easy way to achieve that right now.

1 Like

You’re running into all the snags that caused me to work on the statespace module!

If the actual model you’re going for is linear and Gaussian, you can implement a custom statespace class to handle arbitrary models. Here’s the tutorial notebook on custom classes. Plus you get all the post-estimation stuff for free! I believe the pendulum example falls neatly within the available framework, for example.

If your model is non-linear, you can always linearize it around the steady state (assuming it’s stationary – if it’s not, you can transform it to be stationary). If it’s non-Gaussian, you can often introduce latent states to e.g. introduce auto-correlation in the innovations.

Worst case, you can open a PR for an extended or unscented Kalman filter.

Knowing nothing about the actual problem you want to solve I can’t give more specific advice. But I recommend you give the statespace classes a look.

1 Like

Thanks for your very fast feedback, @ricardoV94 & @jessegrabowski !

Knowing nothing about the actual problem you want to solve I can’t give more specific advice.

As said in my first post, my goal is to model several correlated discrete time series. These are measurements related to some underlying states, which I want to determine and possibly also want to extrapolate into the future. So far, I also don’t know yet what exact model I should implement for this problem. However, for me, it would be easiest to model these time series as a non-linear state space, since it would fit best with my personal intuition about the problem and therefore would allow me to try out several different dynamics - successively from very simple to more complex ones.

The pendulum model is however totally made up, just to have a self-contained example for this thread to show some typical structures. But I would agree that this should (almost) fit into the `PyMCStateSpace` framework.

If your model is non-linear, you can always linearize it around the steady state (assuming it’s stationary – if it’s not, you can transform it to be stationary). If it’s non-Gaussian, you can often introduce latent states to e.g. introduce auto-correlation in the innovations.

I actually expect that the dynamics of the model will have some switching behavior (e.g. saturation or bang-bang dynamics), which cannot sufficiently be approximated by linearization. My rationale for also considering non-Gaussian noise is that the domain of some variables is bounded, e.g. to the non-negative reals only. I’m also not yet sure if the introduction of a few discrete states would make sense - but for now I think it’s reasonable to consider it as a continuous problem.

Given this context, I think @ricardoV94’s suggestion would be the best way forward - to register only the random states as `CustomDist` and to recover the deterministic states afterwards. It is maybe not the most convenient structure for trying out many different dynamics, but likely the best I can get for now. I’ve adapted the pendulum model accordingly:

``````def step(v, x, d, k, sigma, dt):
v_next = -pm.Normal.dist(mu=d * dt * v * pt.abs(v), sigma=sigma) - k * dt * pt.sin(x) + v
x_next = v * dt + x
return [v_next, x_next], pm.pytensorf.collect_default_updates(outputs=[v_next, x_next], inputs=[v, x, d, k, sigma, dt])

def step_d(v, x, d, k, sigma, dt):
    """Deterministic position update used to rebuild ``x`` from velocities.

    Only ``dt`` is used. ``d``, ``k`` and ``sigma`` are accepted because the
    recovery scan passes the same non-sequences as the generative scan. To
    match ``step``, ``v`` must be the velocity of the *previous* step.
    """
    x_next = v * dt + x
    return [x_next], pm.pytensorf.collect_default_updates(outputs=[x_next], inputs=[v, x, d, k, sigma, dt])

def statespace(v0, x0, d, k, sigma, dt, n_steps, *args, **kwargs):
    """Generative graph for the ``v`` CustomDist.

    Scans ``step`` over ``(v, x)`` but returns only the random velocity
    sequence. The deterministic position is not returned and has to be
    rebuilt outside the CustomDist.
    """
    (v, _x), _ = pytensor.scan(
        fn=step,
        sequences=[],
        outputs_info=[v0, x0],
        non_sequences=[d, k, sigma, dt],
        n_steps=n_steps,
        name='statespace',
        strict=True,
    )
    return v

def statespace_d(v, x0, d, k, sigma, dt, n_steps, *args, v0=None, **kwargs):
    """Rebuild the deterministic position sequence from a velocity sequence.

    In the generative ``step``, ``x[t] = x[t-1] + v[t-1] * dt`` uses the
    *previous* velocity. The first position update therefore needs ``v0``,
    and the last sampled velocity is never used for ``x``.

    Parameters
    ----------
    v : velocity sequence returned by ``statespace`` (it excludes ``v0``).
    x0 : initial position.
    d, k, sigma, dt : forwarded to ``step_d`` (only ``dt`` is used there).
    n_steps : number of scan steps.
    v0 : initial velocity. When given, the velocities are lagged by one step
        so the result matches the generative scan. When omitted, ``v`` is
        consumed unshifted (the previous behaviour), which integrates each
        position with the *next* velocity.
    """
    if v0 is not None:
        # Lag the velocities by one step: [v0, v[0], ..., v[-2]].
        v = pt.concatenate([pt.atleast_1d(v0), v[:-1]])
    x, _ = pytensor.scan(
        fn=step_d,
        sequences=[v],
        outputs_info=[x0],
        non_sequences=[d, k, sigma, dt],
        n_steps=n_steps,
        name='statespace_d',
        strict=True,
    )
    return x

with pm.Model() as model:
    time = pd.RangeIndex(0, 200)
    dt = time.step
    n_steps = pm.Data('n_steps', time.size).astype('int')  # For whatever reason, I'm getting an AssertionError during posterior sampling if this is not casted into pm.Data

    k = pm.Gamma('k', mu=0.02, sigma=0.01)
    d = pm.Gamma('d', mu=0.4, sigma=0.2)
    sigma = pm.Gamma('sigma', mu=0.02, sigma=0.01)
    x0 = pm.Normal('x0', mu=0, sigma=0.5)
    v0 = pm.Normal('v0', mu=0, sigma=0.5)
    eps = 0.01

    # Only the random velocity goes through CustomDist.
    v = pm.CustomDist('v', v0, x0, d, k, sigma, dt, n_steps, dist=statespace, dims=('idx_time',))
    # Pass v0 so the rebuilt positions use the previous-step velocity, exactly
    # as the generative step does.
    x = pm.Deterministic('x', statespace_d(v, x0, d, k, sigma, dt, n_steps, v0=v0), dims=('idx_time',))
    state = pm.Deterministic('state', pt.stack([v, x]).T, dims=('idx_time', 'states'))  # Just to keep same variable name for postprocessing

    x_obs = pm.Normal('x_obs', mu=x, sigma=eps, dims=('idx_time',))

# Fix the parameters with pm.do, draw from the prior predictive, and keep one
# draw of x_obs as the synthetic observations.
with pm.do(model, {k: 0.02, d: 0.4, sigma: 0.02, x0: 2.4, v0: 0}) as model_generative:
    idata_generative = pm.sample_prior_predictive()
data_x = idata_generative.prior.x_obs.sel(chain=0, draw=42).values

# Condition x_obs on the synthetic data and draw from the prior predictive.
with pm.observe(model, {x_obs: data_x}) as model_observed:
    idata = pm.sample_prior_predictive()
``````

However, when I sample from this model with `numpyro`, I get the following error after all samples have been drawn:

``````---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
Cell In[30], line 2
1 with model_observed:
----> 2     idata.extend(pm.sample(nuts_sampler='numpyro'))
3     # idata.extend(pm.sample(nuts_sampler='blackjax'))
4     # idata.extend(pm.sample(nuts_sampler='nutpie'))
5     # idata.extend(pm.sample(nuts_sampler='pymc'))
6 idata

File /opt/conda/lib/python3.12/site-packages/pymc/sampling/mcmc.py:725, in sample(draws, tune, chains, cores, random_seed, progressbar, progressbar_theme, step, var_names, nuts_sampler, initvals, init, jitter_max_retries, n_init, trace, discard_tuned_samples, compute_convergence_checks, keep_warning_stat, return_inferencedata, idata_kwargs, nuts_sampler_kwargs, callback, mp_ctx, blas_cores, model, **kwargs)
720         raise ValueError(
721             "Model can not be sampled with NUTS alone. Your model is probably not continuous."
722         )
724     with joined_blas_limiter():
--> 725         return _sample_external_nuts(
726             sampler=nuts_sampler,
727             draws=draws,
728             tune=tune,
729             chains=chains,
730             target_accept=kwargs.pop("nuts", {}).get("target_accept", 0.8),
731             random_seed=random_seed,
732             initvals=initvals,
733             model=model,
734             var_names=var_names,
735             progressbar=progressbar,
736             idata_kwargs=idata_kwargs,
737             compute_convergence_checks=compute_convergence_checks,
738             nuts_sampler_kwargs=nuts_sampler_kwargs,
739             **kwargs,
740         )
742 if isinstance(step, list):
743     step = CompoundStep(step)

File /opt/conda/lib/python3.12/site-packages/pymc/sampling/mcmc.py:356, in _sample_external_nuts(sampler, draws, tune, chains, target_accept, random_seed, initvals, model, var_names, progressbar, idata_kwargs, compute_convergence_checks, nuts_sampler_kwargs, **kwargs)
353 elif sampler in ("numpyro", "blackjax"):
354     import pymc.sampling.jax as pymc_jax
--> 356     idata = pymc_jax.sample_jax_nuts(
357         draws=draws,
358         tune=tune,
359         chains=chains,
360         target_accept=target_accept,
361         random_seed=random_seed,
362         initvals=initvals,
363         model=model,
364         var_names=var_names,
365         progressbar=progressbar,
366         nuts_sampler=sampler,
367         idata_kwargs=idata_kwargs,
368         compute_convergence_checks=compute_convergence_checks,
369         **nuts_sampler_kwargs,
370     )
371     return idata
373 else:

File /opt/conda/lib/python3.12/site-packages/pymc/sampling/jax.py:655, in sample_jax_nuts(draws, tune, chains, target_accept, random_seed, initvals, jitter, model, var_names, nuts_kwargs, progressbar, keep_untransformed, chain_method, postprocessing_backend, postprocessing_vectorize, postprocessing_chunks, idata_kwargs, compute_convergence_checks, nuts_sampler)
652     log_likelihood = None
654 jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=vars_to_sample)
--> 655 result = _postprocess_samples(
656     jax_fn,
657     raw_mcmc_samples,
658     postprocessing_backend=postprocessing_backend,
659     postprocessing_vectorize=postprocessing_vectorize,
660     donate_samples=True,
661 )
662 del raw_mcmc_samples
663 mcmc_samples = {v.name: r for v, r in zip(vars_to_sample, result)}

File /opt/conda/lib/python3.12/site-packages/pymc/sampling/jax.py:205, in _postprocess_samples(jax_fn, raw_mcmc_samples, postprocessing_backend, postprocessing_vectorize, donate_samples)
202     def process_fn(x):
203         return jax.vmap(jax.vmap(jax_fn))(*_device_put(x, postprocessing_backend))
--> 205     return jax.jit(process_fn, donate_argnums=0 if donate_samples else None)(raw_mcmc_samples)
207 else:
208     raise ValueError(f"Unrecognized postprocessing_vectorize: {postprocessing_vectorize}")

[... skipping hidden 11 frame]

File /opt/conda/lib/python3.12/site-packages/pymc/sampling/jax.py:203, in _postprocess_samples.<locals>.process_fn(x)
202 def process_fn(x):
--> 203     return jax.vmap(jax.vmap(jax_fn))(*_device_put(x, postprocessing_backend))

[... skipping hidden 6 frame]

File /tmp/tmppi3fv53j:33, in jax_funcified_fgraph(k_log_, d_log_, sigma_log_, x0, v0, v)
31 scalar_variable = scalar_from_tensor(tensor_variable_10)
32 # Subtensor{start:stop:step}(v, 0, ScalarFromTensor.0, 1)
---> 33 tensor_variable_11 = subtensor(v, scalar_constant_1, scalar_variable, scalar_constant_2)
34 # Scan{statespace, while_loop=False, inplace=none}(200, Subtensor{start:stop:step}.0, SetSubtensor{:stop}.0)
35 tensor_variable_12 = scan(tensor_constant_1, tensor_variable_11, tensor_variable_4)

42 if len(indices) == 1:
43     indices = indices[0]
---> 45 return x.__getitem__(indices)

File /opt/conda/lib/python3.12/site-packages/jax/_src/numpy/array_methods.py:739, in _forward_operator_to_aval.<locals>.op(self, *args)
738 def op(self, *args):
--> 739   return getattr(self.aval, f"_{name}")(self, *args)

File /opt/conda/lib/python3.12/site-packages/jax/_src/numpy/array_methods.py:352, in _getitem(self, item)
351 def _getitem(self, item):
--> 352   return lax_numpy._rewriting_take(self, item)

File /opt/conda/lib/python3.12/site-packages/jax/_src/numpy/lax_numpy.py:6594, in _rewriting_take(arr, idx, indices_are_sorted, unique_indices, mode, fill_value)
6591       return lax.dynamic_index_in_dim(arr, idx, keepdims=False)
6593 treedef, static_idx, dynamic_idx = _split_index_for_jit(idx, arr.shape)
-> 6594 return _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted,
6595                unique_indices, mode, fill_value)

File /opt/conda/lib/python3.12/site-packages/jax/_src/numpy/lax_numpy.py:6603, in _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted, unique_indices, mode, fill_value)
6600 def _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted,
6601             unique_indices, mode, fill_value):
6602   idx = _merge_static_and_dynamic_indices(treedef, static_idx, dynamic_idx)
-> 6603   indexer = _index_to_gather(shape(arr), idx)  # shared with _scatter_update
6604   y = arr
6606   if fill_value is not None:

File /opt/conda/lib/python3.12/site-packages/jax/_src/numpy/lax_numpy.py:6846, in _index_to_gather(x_shape, idx, normalize_indices)
6837 if not all(_is_slice_element_none_or_constant_or_symbolic(elt)
6838            for elt in (i.start, i.stop, i.step)):
6839   msg = ("Array slice indices must have static start/stop/step to be used "
6840          "with NumPy indexing syntax. "
6841          f"Found slice({i.start}, {i.stop}, {i.step}). "
(...)
6844          "dynamic_update_slice (JAX does not support dynamically sized "
6845          "arrays within JIT compiled functions).")
-> 6846   raise IndexError(msg)
6848 start, step, slice_size = _preprocess_slice(i, x_shape[x_axis])
6849 slice_shape.append(slice_size)

IndexError: Array slice indices must have static start/stop/step to be used with NumPy indexing syntax. Found slice(0, Traced<ShapedArray(int64[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, 1). To index a statically sized array at a dynamic position, try lax.dynamic_slice/dynamic_update_slice (JAX does not support dynamically sized arrays within JIT compiled functions).
``````

With `blackjax` & `nutpie` I get similar error messages before sampling has started. The `pymc` NUTS is unfortunately too slow to give me some results in reasonable time.

``````---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
Cell In[34], line 3
1 with model_observed:
2     # idata.extend(pm.sample(nuts_sampler='numpyro'))
----> 3     idata.extend(pm.sample(nuts_sampler='blackjax'))
4     # idata.extend(pm.sample(nuts_sampler='nutpie'))
5     # idata.extend(pm.sample(nuts_sampler='pymc'))
6 idata

File /opt/conda/lib/python3.12/site-packages/pymc/sampling/mcmc.py:725, in sample(draws, tune, chains, cores, random_seed, progressbar, progressbar_theme, step, var_names, nuts_sampler, initvals, init, jitter_max_retries, n_init, trace, discard_tuned_samples, compute_convergence_checks, keep_warning_stat, return_inferencedata, idata_kwargs, nuts_sampler_kwargs, callback, mp_ctx, blas_cores, model, **kwargs)
720         raise ValueError(
721             "Model can not be sampled with NUTS alone. Your model is probably not continuous."
722         )
724     with joined_blas_limiter():
--> 725         return _sample_external_nuts(
726             sampler=nuts_sampler,
727             draws=draws,
728             tune=tune,
729             chains=chains,
730             target_accept=kwargs.pop("nuts", {}).get("target_accept", 0.8),
731             random_seed=random_seed,
732             initvals=initvals,
733             model=model,
734             var_names=var_names,
735             progressbar=progressbar,
736             idata_kwargs=idata_kwargs,
737             compute_convergence_checks=compute_convergence_checks,
738             nuts_sampler_kwargs=nuts_sampler_kwargs,
739             **kwargs,
740         )
742 if isinstance(step, list):
743     step = CompoundStep(step)

File /opt/conda/lib/python3.12/site-packages/pymc/sampling/mcmc.py:356, in _sample_external_nuts(sampler, draws, tune, chains, target_accept, random_seed, initvals, model, var_names, progressbar, idata_kwargs, compute_convergence_checks, nuts_sampler_kwargs, **kwargs)
353 elif sampler in ("numpyro", "blackjax"):
354     import pymc.sampling.jax as pymc_jax
--> 356     idata = pymc_jax.sample_jax_nuts(
357         draws=draws,
358         tune=tune,
359         chains=chains,
360         target_accept=target_accept,
361         random_seed=random_seed,
362         initvals=initvals,
363         model=model,
364         var_names=var_names,
365         progressbar=progressbar,
366         nuts_sampler=sampler,
367         idata_kwargs=idata_kwargs,
368         compute_convergence_checks=compute_convergence_checks,
369         **nuts_sampler_kwargs,
370     )
371     return idata
373 else:

File /opt/conda/lib/python3.12/site-packages/pymc/sampling/jax.py:655, in sample_jax_nuts(draws, tune, chains, target_accept, random_seed, initvals, jitter, model, var_names, nuts_kwargs, progressbar, keep_untransformed, chain_method, postprocessing_backend, postprocessing_vectorize, postprocessing_chunks, idata_kwargs, compute_convergence_checks, nuts_sampler)
652     log_likelihood = None
654 jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=vars_to_sample)
--> 655 result = _postprocess_samples(
656     jax_fn,
657     raw_mcmc_samples,
658     postprocessing_backend=postprocessing_backend,
659     postprocessing_vectorize=postprocessing_vectorize,
660     donate_samples=True,
661 )
662 del raw_mcmc_samples
663 mcmc_samples = {v.name: r for v, r in zip(vars_to_sample, result)}

File /opt/conda/lib/python3.12/site-packages/pymc/sampling/jax.py:205, in _postprocess_samples(jax_fn, raw_mcmc_samples, postprocessing_backend, postprocessing_vectorize, donate_samples)
202     def process_fn(x):
203         return jax.vmap(jax.vmap(jax_fn))(*_device_put(x, postprocessing_backend))
--> 205     return jax.jit(process_fn, donate_argnums=0 if donate_samples else None)(raw_mcmc_samples)
207 else:
208     raise ValueError(f"Unrecognized postprocessing_vectorize: {postprocessing_vectorize}")

[... skipping hidden 11 frame]

File /opt/conda/lib/python3.12/site-packages/pymc/sampling/jax.py:203, in _postprocess_samples.<locals>.process_fn(x)
202 def process_fn(x):
--> 203     return jax.vmap(jax.vmap(jax_fn))(*_device_put(x, postprocessing_backend))

[... skipping hidden 6 frame]

File /tmp/tmpx0adh6st:33, in jax_funcified_fgraph(k_log_, d_log_, sigma_log_, x0, v0, v)
31 scalar_variable = scalar_from_tensor(tensor_variable_10)
32 # Subtensor{start:stop:step}(v, 0, ScalarFromTensor.0, 1)
---> 33 tensor_variable_11 = subtensor(v, scalar_constant_1, scalar_variable, scalar_constant_2)
34 # Scan{statespace, while_loop=False, inplace=none}(200, Subtensor{start:stop:step}.0, SetSubtensor{:stop}.0)
35 tensor_variable_12 = scan(tensor_constant_1, tensor_variable_11, tensor_variable_4)

42 if len(indices) == 1:
43     indices = indices[0]
---> 45 return x.__getitem__(indices)

File /opt/conda/lib/python3.12/site-packages/jax/_src/numpy/array_methods.py:739, in _forward_operator_to_aval.<locals>.op(self, *args)
738 def op(self, *args):
--> 739   return getattr(self.aval, f"_{name}")(self, *args)

File /opt/conda/lib/python3.12/site-packages/jax/_src/numpy/array_methods.py:352, in _getitem(self, item)
351 def _getitem(self, item):
--> 352   return lax_numpy._rewriting_take(self, item)

File /opt/conda/lib/python3.12/site-packages/jax/_src/numpy/lax_numpy.py:6594, in _rewriting_take(arr, idx, indices_are_sorted, unique_indices, mode, fill_value)
6591       return lax.dynamic_index_in_dim(arr, idx, keepdims=False)
6593 treedef, static_idx, dynamic_idx = _split_index_for_jit(idx, arr.shape)
-> 6594 return _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted,
6595                unique_indices, mode, fill_value)

File /opt/conda/lib/python3.12/site-packages/jax/_src/numpy/lax_numpy.py:6603, in _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted, unique_indices, mode, fill_value)
6600 def _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted,
6601             unique_indices, mode, fill_value):
6602   idx = _merge_static_and_dynamic_indices(treedef, static_idx, dynamic_idx)
-> 6603   indexer = _index_to_gather(shape(arr), idx)  # shared with _scatter_update
6604   y = arr
6606   if fill_value is not None:

File /opt/conda/lib/python3.12/site-packages/jax/_src/numpy/lax_numpy.py:6846, in _index_to_gather(x_shape, idx, normalize_indices)
6837 if not all(_is_slice_element_none_or_constant_or_symbolic(elt)
6838            for elt in (i.start, i.stop, i.step)):
6839   msg = ("Array slice indices must have static start/stop/step to be used "
6840          "with NumPy indexing syntax. "
6841          f"Found slice({i.start}, {i.stop}, {i.step}). "
(...)
6844          "dynamic_update_slice (JAX does not support dynamically sized "
6845          "arrays within JIT compiled functions).")
-> 6846   raise IndexError(msg)
6848 start, step, slice_size = _preprocess_slice(i, x_shape[x_axis])
6849 slice_shape.append(slice_size)

IndexError: Array slice indices must have static start/stop/step to be used with NumPy indexing syntax. Found slice(0, Traced<ShapedArray(int64[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, 1). To index a statically sized array at a dynamic position, try lax.dynamic_slice/dynamic_update_slice (JAX does not support dynamically sized arrays within JIT compiled functions).
``````

Any idea how this could be fixed? Right now, I assume it is a bug within the samplers. I’ve got the following versions installed:

``````pymc                          5.16.2
pytensor                      2.25.2
arviz                         0.19.0
numpyro                       0.15.1
blackjax                      1.2.2
nutpie                        0.13.2
jax                           0.4.30
numpy                         1.26.4
``````

By the way, is there any method to programmatically check if a `pytensor` variable is random or deterministic, such as `assert x.is_deterministic()` and `assert v.is_random()`?

You’re running into all the snags that caused me to work on the statespace module!

JAX is very finicky about shapes, so you can try applying `freeze_dims_and_data` to the model before attempting to sample with numpyro:

https://www.pymc.io/projects/docs/en/v5.13.1/api/model/generated/pymc.model.transform.optimization.freeze_dims_and_data.html

For nutpie that’s more surprising, since it uses the Numba backend. Could you post the traceback?

Thanks for the hint about `freeze_dims_and_data`! Unfortunately, with `numpyro` or `blackjax` I’m still getting the same error messages as above.

However, `nutpie` now runs when I use `freeze_dims_and_data`! The results don’t look very good, but that could also be caused by the model structure, which I may be able to improve further.

Without `freeze_dims_and_data` I get the following error:

``````---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[33], line 4
1 with model_observed:
2     # idata.extend(pm.sample(nuts_sampler='numpyro'))
3     # idata.extend(pm.sample(nuts_sampler='blackjax'))
----> 4     idata.extend(pm.sample(nuts_sampler='nutpie'))
5     # idata.extend(pm.sample(nuts_sampler='pymc'))
6 idata

File /opt/conda/lib/python3.12/site-packages/pymc/sampling/mcmc.py:725, in sample(draws, tune, chains, cores, random_seed, progressbar, progressbar_theme, step, var_names, nuts_sampler, initvals, init, jitter_max_retries, n_init, trace, discard_tuned_samples, compute_convergence_checks, keep_warning_stat, return_inferencedata, idata_kwargs, nuts_sampler_kwargs, callback, mp_ctx, blas_cores, model, **kwargs)
720         raise ValueError(
721             "Model can not be sampled with NUTS alone. Your model is probably not continuous."
722         )
724     with joined_blas_limiter():
--> 725         return _sample_external_nuts(
726             sampler=nuts_sampler,
727             draws=draws,
728             tune=tune,
729             chains=chains,
730             target_accept=kwargs.pop("nuts", {}).get("target_accept", 0.8),
731             random_seed=random_seed,
732             initvals=initvals,
733             model=model,
734             var_names=var_names,
735             progressbar=progressbar,
736             idata_kwargs=idata_kwargs,
737             compute_convergence_checks=compute_convergence_checks,
738             nuts_sampler_kwargs=nuts_sampler_kwargs,
739             **kwargs,
740         )
742 if isinstance(step, list):
743     step = CompoundStep(step)

File /opt/conda/lib/python3.12/site-packages/pymc/sampling/mcmc.py:307, in _sample_external_nuts(sampler, draws, tune, chains, target_accept, random_seed, initvals, model, var_names, progressbar, idata_kwargs, compute_convergence_checks, nuts_sampler_kwargs, **kwargs)
302 if var_names is not None:
303     warnings.warn(
304         "`var_names` are currently ignored by the nutpie sampler",
305         UserWarning,
306     )
--> 307 compiled_model = nutpie.compile_pymc_model(model)
308 t_start = time.time()
309 idata = nutpie.sample(
310     compiled_model,
311     draws=draws,
(...)
317     **nuts_sampler_kwargs,
318 )

File /opt/conda/lib/python3.12/site-packages/nutpie/compile_pymc.py:391, in compile_pymc_model(model, backend, gradient_backend, **kwargs)
388     backend = "numba"
390 if backend.lower() == "numba":
--> 391     return _compile_pymc_model_numba(model, **kwargs)
392 elif backend.lower() == "jax":
393     return _compile_pymc_model_jax(
395     )

File /opt/conda/lib/python3.12/site-packages/nutpie/compile_pymc.py:207, in _compile_pymc_model_numba(model, **kwargs)
200 with warnings.catch_warnings():
201     warnings.filterwarnings(
202         "ignore",
203         message="Cannot cache compiled function .* as it uses dynamic globals",
204         category=numba.NumbaWarning,
205     )
--> 207     logp_numba = numba.cfunc(c_sig, **kwargs)(logp_numba_raw)
209 expand_shared_names = [var.name for var in expand_fn_pt.get_shared()]
210 expand_numba_raw, c_sig_expand = _make_c_expand_func(
211     n_dim, n_expanded, expand_fn, user_data, expand_shared_names, shared_data
212 )

File /opt/conda/lib/python3.12/site-packages/numba/core/decorators.py:275, in cfunc.<locals>.wrapper(func)
273 if cache:
274     res.enable_caching()
--> 275 res.compile()
276 return res

File /opt/conda/lib/python3.12/site-packages/numba/core/compiler_lock.py:35, in _CompilerLock.__call__.<locals>._acquire_compile_lock(*args, **kwargs)
32 @functools.wraps(func)
33 def _acquire_compile_lock(*args, **kwargs):
34     with self:
---> 35         return func(*args, **kwargs)

File /opt/conda/lib/python3.12/site-packages/numba/core/ccallback.py:68, in CFunc.compile(self)
66                                  self._targetdescr.target_context)
67 if cres is None:
---> 68     cres = self._compile_uncached()
70 else:

File /opt/conda/lib/python3.12/site-packages/numba/core/ccallback.py:82, in CFunc._compile_uncached(self)
79 sig = self._sig
81 # Compile native function as well as cfunc wrapper
---> 82 return self._compiler.compile(sig.args, sig.return_type)

File /opt/conda/lib/python3.12/site-packages/numba/core/dispatcher.py:80, in _FunctionCompiler.compile(self, args, return_type)
79 def compile(self, args, return_type):
---> 80     status, retval = self._compile_cached(args, return_type)
81     if status:
82         return retval

File /opt/conda/lib/python3.12/site-packages/numba/core/dispatcher.py:94, in _FunctionCompiler._compile_cached(self, args, return_type)
91     pass
93 try:
---> 94     retval = self._compile_core(args, return_type)
95 except errors.TypingError as e:
96     self._failed_cache[key] = e

File /opt/conda/lib/python3.12/site-packages/numba/core/dispatcher.py:107, in _FunctionCompiler._compile_core(self, args, return_type)
104 flags = self._customize_flags(flags)
106 impl = self._get_implementation(args, {})
--> 107 cres = compiler.compile_extra(self.targetdescr.typing_context,
108                               self.targetdescr.target_context,
109                               impl,
110                               args=args, return_type=return_type,
111                               flags=flags, locals=self.locals,
112                               pipeline_class=self.pipeline_class)
113 # Check typing error if object mode is used
114 if cres.typing_error is not None and not flags.enable_pyobject:

File /opt/conda/lib/python3.12/site-packages/numba/core/compiler.py:744, in compile_extra(typingctx, targetctx, func, args, return_type, flags, locals, library, pipeline_class)
720 """Compiler entry point
721
722 Parameter
(...)
740     compiler pipeline
741 """
742 pipeline = pipeline_class(typingctx, targetctx, library,
743                           args, return_type, flags, locals)
--> 744 return pipeline.compile_extra(func)

File /opt/conda/lib/python3.12/site-packages/numba/core/compiler.py:438, in CompilerBase.compile_extra(self, func)
436 self.state.lifted = ()
437 self.state.lifted_from = None
--> 438 return self._compile_bytecode()

File /opt/conda/lib/python3.12/site-packages/numba/core/compiler.py:506, in CompilerBase._compile_bytecode(self)
502 """
503 Populate and run pipeline for bytecode input
504 """
505 assert self.state.func_ir is None
--> 506 return self._compile_core()

File /opt/conda/lib/python3.12/site-packages/numba/core/compiler.py:481, in CompilerBase._compile_core(self)
478 except Exception as e:
479     if (utils.use_new_style_errors() and not
480             isinstance(e, errors.NumbaError)):
--> 481         raise e
483     self.state.status.fail_reason = e
484     if is_final_pipeline:

File /opt/conda/lib/python3.12/site-packages/numba/core/compiler.py:472, in CompilerBase._compile_core(self)
470 res = None
471 try:
--> 472     pm.run(self.state)
473     if self.state.cr is not None:
474         break

File /opt/conda/lib/python3.12/site-packages/numba/core/compiler_machinery.py:364, in PassManager.run(self, state)
361 except Exception as e:
362     if (utils.use_new_style_errors() and not
363             isinstance(e, errors.NumbaError)):
--> 364         raise e
365     msg = "Failed in %s mode pipeline (step: %s)" % \
366         (self.pipeline_name, pass_desc)
367     patched_exception = self._patch_error(msg, e)

File /opt/conda/lib/python3.12/site-packages/numba/core/compiler_machinery.py:356, in PassManager.run(self, state)
354 pass_inst = _pass_registry.get(pss).pass_inst
355 if isinstance(pass_inst, CompilerPass):
--> 356     self._runPass(idx, pass_inst, state)
357 else:
358     raise BaseException("Legacy pass in use")

File /opt/conda/lib/python3.12/site-packages/numba/core/compiler_lock.py:35, in _CompilerLock.__call__.<locals>._acquire_compile_lock(*args, **kwargs)
32 @functools.wraps(func)
33 def _acquire_compile_lock(*args, **kwargs):
34     with self:
---> 35         return func(*args, **kwargs)

File /opt/conda/lib/python3.12/site-packages/numba/core/compiler_machinery.py:311, in PassManager._runPass(self, index, pss, internal_state)
309     mutated |= check(pss.run_initialization, internal_state)
310 with SimpleTimer() as pass_time:
--> 311     mutated |= check(pss.run_pass, internal_state)
312 with SimpleTimer() as finalize_time:
313     mutated |= check(pss.run_finalizer, internal_state)

File /opt/conda/lib/python3.12/site-packages/numba/core/compiler_machinery.py:273, in PassManager._runPass.<locals>.check(func, compiler_state)
272 def check(func, compiler_state):
--> 273     mangled = func(compiler_state)
274     if mangled not in (True, False):
275         msg = ("CompilerPass implementations should return True/False. "
276                "CompilerPass with name '%s' did not.")

File /opt/conda/lib/python3.12/site-packages/numba/core/untyped_passes.py:1731, in LiteralUnroll.run_pass(self, state)
1730 pm.finalize()
-> 1731 pm.run(state)
1732 return True

File /opt/conda/lib/python3.12/site-packages/numba/core/compiler_machinery.py:364, in PassManager.run(self, state)
361 except Exception as e:
362     if (utils.use_new_style_errors() and not
363             isinstance(e, errors.NumbaError)):
--> 364         raise e
365     msg = "Failed in %s mode pipeline (step: %s)" % \
366         (self.pipeline_name, pass_desc)
367     patched_exception = self._patch_error(msg, e)

File /opt/conda/lib/python3.12/site-packages/numba/core/compiler_machinery.py:356, in PassManager.run(self, state)
354 pass_inst = _pass_registry.get(pss).pass_inst
355 if isinstance(pass_inst, CompilerPass):
--> 356     self._runPass(idx, pass_inst, state)
357 else:
358     raise BaseException("Legacy pass in use")

File /opt/conda/lib/python3.12/site-packages/numba/core/compiler_lock.py:35, in _CompilerLock.__call__.<locals>._acquire_compile_lock(*args, **kwargs)
32 @functools.wraps(func)
33 def _acquire_compile_lock(*args, **kwargs):
34     with self:
---> 35         return func(*args, **kwargs)

File /opt/conda/lib/python3.12/site-packages/numba/core/compiler_machinery.py:311, in PassManager._runPass(self, index, pss, internal_state)
309     mutated |= check(pss.run_initialization, internal_state)
310 with SimpleTimer() as pass_time:
--> 311     mutated |= check(pss.run_pass, internal_state)
312 with SimpleTimer() as finalize_time:
313     mutated |= check(pss.run_finalizer, internal_state)

File /opt/conda/lib/python3.12/site-packages/numba/core/compiler_machinery.py:273, in PassManager._runPass.<locals>.check(func, compiler_state)
272 def check(func, compiler_state):
--> 273     mangled = func(compiler_state)
274     if mangled not in (True, False):
275         msg = ("CompilerPass implementations should return True/False. "
276                "CompilerPass with name '%s' did not.")

File /opt/conda/lib/python3.12/site-packages/numba/core/typed_passes.py:112, in BaseTypeInference.run_pass(self, state)
106 """
107 Type inference and legalization
108 """
109 with fallback_context(state, 'Function "%s" failed type inference'
110                       % (state.func_id.func_name,)):
111     # Type inference
--> 112     typemap, return_type, calltypes, errs = type_inference_stage(
113         state.typingctx,
114         state.targetctx,
115         state.func_ir,
116         state.args,
117         state.return_type,
118         state.locals,
119         raise_errors=self._raise_errors)
120     state.typemap = typemap
121     # save errors in case of partial typing

File /opt/conda/lib/python3.12/site-packages/numba/core/typed_passes.py:93, in type_inference_stage(typingctx, targetctx, interp, args, return_type, locals, raise_errors)
91     infer.build_constraint()
92     # return errors in case of partial typing
---> 93     errs = infer.propagate(raise_errors=raise_errors)
94     typemap, restype, calltypes = infer.unify(raise_errors=raise_errors)
96 return _TypingResults(typemap, restype, calltypes, errs)

File /opt/conda/lib/python3.12/site-packages/numba/core/typeinfer.py:1083, in TypeInferer.propagate(self, raise_errors)
1080 oldtoken = newtoken
1081 # Errors can appear when the type set is incomplete; only
1082 # raise them when there is no progress anymore.
-> 1083 errors = self.constraints.propagate(self)
1084 newtoken = self.get_state_token()
1085 self.debug.propagate_finished()

File /opt/conda/lib/python3.12/site-packages/numba/core/typeinfer.py:182, in ConstraintNetwork.propagate(self, typeinfer)
180     errors.append(utils.chain_exception(new_exc, e))
181 elif utils.use_new_style_errors():
--> 182     raise e
183 else:
184     msg = ("Unknown CAPTURED_ERRORS style: "
185            f"'{config.CAPTURED_ERRORS}'.")

File /opt/conda/lib/python3.12/site-packages/numba/core/typeinfer.py:160, in ConstraintNetwork.propagate(self, typeinfer)
157 with typeinfer.warnings.catch_warnings(filename=loc.filename,
158                                        lineno=loc.line):
159     try:
--> 160         constraint(typeinfer)
161     except ForceLiteralArg as e:
162         errors.append(e)

File /opt/conda/lib/python3.12/site-packages/numba/core/typeinfer.py:583, in CallConstraint.__call__(self, typeinfer)
581     fnty = typevars[self.func].getone()
582 with new_error_context("resolving callee type: {0}", fnty):
--> 583     self.resolve(typeinfer, typevars, fnty)

File /opt/conda/lib/python3.12/site-packages/numba/core/typeinfer.py:606, in CallConstraint.resolve(self, typeinfer, typevars, fnty)
604     fnty = fnty.instance_type
605 try:
--> 606     sig = typeinfer.resolve_call(fnty, pos_args, kw_args)
...

AttributeError: module 'numpy' has no attribute 'bool'.
`np.bool` was a deprecated alias for the builtin `bool`. To avoid this error in existing code, use `bool` by itself. Doing this will not modify any behavior and is safe. If you specifically wanted the numpy scalar type, use `np.bool_` here.
The aliases was originally deprecated in NumPy 1.20; for more details and guidance see the original release note at:
https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
``````

Interesting, can you provide a minimal reproducible example? Sounds like something we should be able to patch.

Sure, it is basically the example above plus some `import`s:

``````import pandas as pd
import pymc as pm
import pytensor.tensor as pt
import pytensor

def step(v, x, d, k, sigma, dt):
    """Advance the (velocity, position) state by one scan iteration.

    Used as the ``fn`` of the scan in ``statespace``, where ``v`` and ``x``
    are the recurrent states from the previous iteration and ``d``, ``k``,
    ``sigma`` and ``dt`` are passed as non-sequences.

    New velocity: ``v`` minus a Normal draw centred on ``d*dt*v*|v|`` with
    scale ``sigma``, minus ``k*dt*sin(x)``.
    New position: ``x + v*dt``, using the incoming (previous) velocity.

    Returns
    -------
    ``[v_next, x_next]`` and the update dictionary produced by
    ``pm.pytensorf.collect_default_updates`` for these outputs, so the
    random draw's RNG state is carried between iterations.
    """
    damping_mean = d * dt * v * pt.abs(v)
    damping_draw = pm.Normal.dist(mu=damping_mean, sigma=sigma)
    # Keep the exact operand order of the velocity update: (-draw) - sin-term + v.
    v_next = -damping_draw - k * dt * pt.sin(x) + v
    x_next = v * dt + x

    rng_updates = pm.pytensorf.collect_default_updates(
        outputs=[v_next, x_next],
        inputs=[v, x, d, k, sigma, dt],
    )
    return [v_next, x_next], rng_updates

def step_d(v, x, d, k, sigma, dt):
    """Integrate position by one step from a given velocity (no random draw).

    ``d``, ``k`` and ``sigma`` are accepted only so the argument list mirrors
    ``step``; they are not used here.

    Returns
    -------
    ``[v*dt + x]`` and the update dictionary from
    ``pm.pytensorf.collect_default_updates`` (presumably empty, since no
    random variable is created in this step).
    """
    new_position = v * dt + x
    rng_updates = pm.pytensorf.collect_default_updates(
        outputs=[new_position],
        inputs=[v, x, d, k, sigma, dt],
    )
    return [new_position], rng_updates

def statespace(v0, x0, d, k, sigma, dt, n_steps, *args, **kwargs):
    """Generate the stochastic velocity trajectory by scanning ``step``.

    Intended as the ``dist`` function of ``pm.CustomDist('v', ...)``.

    Parameters
    ----------
    v0, x0
        Initial velocity and position, used as the scan's ``outputs_info``.
    d, k, sigma, dt
        Forwarded unchanged to ``step`` as non-sequences.
    n_steps
        Number of scan iterations.
    *args, **kwargs
        Any extra arguments from the caller are accepted and ignored.

    Returns
    -------
    The velocity outputs of the scan, one per iteration (``v0`` itself is
    not included). The position trajectory computed alongside it is
    discarded.
    """
    # ``step`` returns two recurrent outputs, so scan yields [v, x] plus the
    # RNG updates (not needed by the caller here).
    (v, _x), _updates = pytensor.scan(
        fn=step,
        sequences=[],
        outputs_info=[v0, x0],
        non_sequences=[d, k, sigma, dt],
        n_steps=n_steps,
        name='statespace',
        strict=True,
    )
    return v

def statespace_d(v, x0, d, k, sigma, dt, n_steps, *args, **kwargs):
    """Deterministically integrate a velocity trajectory into positions.

    Scans ``step_d`` over the sequence ``v`` starting from ``x0``, so
    iteration t computes ``x_t = v[t]*dt + x_{t-1}``.

    NOTE(review): in ``step`` the position update uses the *previous*
    velocity (the recurrent input), whereas here each step uses the
    *current* element of ``v``. The positions returned here are therefore
    shifted by one velocity sample relative to the positions that drive the
    dynamics inside ``statespace``. Confirm this is intended.

    Parameters
    ----------
    v
        Velocity trajectory, iterated as the scan sequence.
    x0
        Initial position (``outputs_info``).
    d, k, sigma, dt
        Forwarded to ``step_d``; only ``dt`` is used there, the others keep
        the argument list parallel to ``step``.
    n_steps
        Number of scan iterations.
    *args, **kwargs
        Accepted and ignored.

    Returns
    -------
    The position trajectory, one entry per iteration.
    """
    # Single recurrent output -> scan returns the variable itself, not a list.
    x, _updates = pytensor.scan(
        fn=step_d,
        sequences=[v],
        outputs_info=[x0],
        non_sequences=[d, k, sigma, dt],
        n_steps=n_steps,
        # Distinct from the stochastic scan's name so the two are
        # distinguishable in debug prints and profiles.
        name='statespace_d',
        strict=True,
    )
    return x

# Minimal reproducible example: a noisy velocity process (CustomDist built
# from `statespace`), its integrated position `x`, and a Normal observation of
# `x`. Data are simulated from fixed parameter values via pm.do, then the
# model is conditioned on them via pm.observe and sampled with nutpie.
time = pd.RangeIndex(0, 200)
# Declare the coordinates so that every `dims` label used below
# ('idx_time', 'states') resolves to a known model dimension.
coords = {'idx_time': time, 'states': ['v', 'x']}

with pm.Model(coords=coords) as model:
    dt = time.step
    # The author reports an AssertionError during posterior sampling unless
    # n_steps is wrapped in pm.Data; it is cast to an integer because it is
    # used as the scan step count.
    n_steps = pm.Data('n_steps', time.size).astype('int')

    k = pm.Gamma('k', mu=0.02, sigma=0.01)
    d = pm.Gamma('d', mu=0.4, sigma=0.2)
    sigma = pm.Gamma('sigma', mu=0.02, sigma=0.01)
    x0 = pm.Normal('x0', mu=0, sigma=0.5)
    v0 = pm.Normal('v0', mu=0, sigma=0.5)
    eps = 0.01

    v = pm.CustomDist('v', v0, x0, d, k, sigma, dt, n_steps, dist=statespace, dims=('idx_time',))
    x = pm.Deterministic('x', statespace_d(v, x0, d, k, sigma, dt, n_steps), dims=('idx_time',))
    # Stacked (idx_time, states) view, kept under this name for post-processing.
    state = pm.Deterministic('state', pt.stack([v, x]).T, dims=('idx_time', 'states'))

    x_obs = pm.Normal('x_obs', mu=x, sigma=eps, dims=('idx_time',))

# Fix the parameters and draw synthetic observations from the prior.
with pm.do(model, {k: 0.02, d: 0.4, sigma: 0.02, x0: 2.4, v0: 0}) as model_generative:
    idata_generative = pm.sample_prior_predictive()
    data_x = idata_generative.prior.x_obs.sel(chain=0, draw=42).values

# Condition on the simulated data and sample the posterior.
# NOTE(review): the reported `np.bool` AttributeError is raised inside numba's
# type inference while nutpie compiles the logp, not by this script. A version
# mismatch between numpy and the compiled library code seems likely; confirm.
with pm.observe(model, {x_obs: data_x}) as model_observed:
    idata = pm.sample(nuts_sampler='nutpie')
``````