State Space Model with Random & Deterministic Dynamics

Hey, I’m currently trying to implement a general State Space Model using PyTensor’s scan function.
I’ve been roughly following the approach shown in these notebooks and in this discussion.
In the end, my goal is to model several correlated discrete time series with non-linear dynamics and non-Gaussian noise.
However, for now I started with a very simple AR(1) model plus an explicitly integrated time state, to get a general feel for these models and to extend the implementation later.

So, this is the model I currently have:

def step(state, rho, sigma, mu):
    t, x = state
    t_next = t + 1
    # t_next = pm.Normal.dist(mu=(t + 1), sigma=0.01)  # This works, but it's not physically meaningful
    x_next = pm.Normal.dist(mu=(mu + rho * x), sigma=sigma)
    state_next = pt.stack([t_next, x_next])
    return state_next, pm.pytensorf.collect_default_updates(outputs=[state_next], inputs=[state, rho, sigma, mu])

def statespace(t0, x0, rho, sigma, mu, n_steps, *args, **kwargs):
    state0 = pt.stack([t0, x0])
    state, updates = pytensor.scan(
        fn=step,
        sequences=[],
        outputs_info=state0,
        non_sequences=[rho, sigma, mu],
        n_steps=n_steps,
        name='statespace',
        strict=True,
    )
    return state

with pm.Model() as model:
    time = pd.RangeIndex(1950, 2100) + 1
    model.add_coord('idx_time', time)
    model.add_coord('idx_obs', df_train.index)
    model.add_coord('states', ['t', 'x'])
    
    time = pm.Data('time', time, dims=('idx_time'))
    data_t = pm.Data('data_t', df_train['Year'], dims=('idx_obs'))
    data_x = pm.Data('data_x', df_train['CP Rate'], dims=('idx_obs'))

    t0 = time.min() - 1
    n_steps = time.size
    x0 = pm.Normal('x0', mu=0, sigma=100)
    mu = pm.Normal('mu', mu=0, sigma=100)
    rho = pm.Normal('rho', mu=0, sigma=1)
    sigma = pm.Gamma('sigma', mu=10, sigma=10)
    eps = 0.1

    state = pm.CustomDist('state', t0, x0, rho, sigma, mu, n_steps, dist=statespace, dims=('idx_time', 'states'))
    t = state[:, 0]
    x = state[:, 1]
    x_obs = pm.Normal('x_obs', mu=pt.eq(data_t[:, None], time) @ x, sigma=eps, observed=data_x, dims=('idx_obs'))
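The likelihood line above uses `pt.eq(data_t[:, None], time) @ x` to pick, for every observation, the latent value at the matching year. A NumPy sketch of the same indexing trick, assuming every observed year lies on the time grid; the trajectory `x` and the observation years `data_t` are made-up toy values, not the poster's data:

```python
import numpy as np

# Time grid as in the model: pd.RangeIndex(1950, 2100) + 1 -> years 1951..2100
time = np.arange(1950, 2100) + 1
# Hypothetical latent trajectory on that grid (toy values)
x = np.linspace(0.0, 1.0, time.size)
# Hypothetical observation years (unordered, with a repeat)
data_t = np.array([1960, 1955, 2000, 1960])

# (n_obs, n_time) one-hot matrix: row i has a single 1 where time == data_t[i]
match = (data_t[:, None] == time[None, :]).astype(float)

# The matrix product selects x at each observed year
x_at_obs = match @ x

# Equivalent integer indexing, valid because the grid has step 1
idx = data_t - time[0]
```

One design note: with the matrix form, a year that is missing from the grid silently yields 0 instead of an error, so a direct integer index (or a check that every row of `match` sums to 1) may be the safer choice.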

Prior predictive sampling works fine. However, during posterior inference with pm.sample() I get the following error message:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[88], line 2
      1 with model:
----> 2     idata.extend(pm.sample(nuts_sampler='numpyro'))

File /opt/conda/lib/python3.12/site-packages/pymc/sampling/mcmc.py:716, in sample(draws, tune, chains, cores, random_seed, progressbar, progressbar_theme, step, var_names, nuts_sampler, initvals, init, jitter_max_retries, n_init, trace, discard_tuned_samples, compute_convergence_checks, keep_warning_stat, return_inferencedata, idata_kwargs, nuts_sampler_kwargs, callback, mp_ctx, blas_cores, model, **kwargs)
    713         auto_nuts_init = False
    715 initial_points = None
--> 716 step = assign_step_methods(model, step, methods=pm.STEP_METHODS, step_kwargs=kwargs)
    718 if nuts_sampler != "pymc":
    719     if not isinstance(step, NUTS):

File /opt/conda/lib/python3.12/site-packages/pymc/sampling/mcmc.py:215, in assign_step_methods(model, step, methods, step_kwargs)
    213 methods_list: list[type[BlockedStep]] = list(methods or pm.STEP_METHODS)
    214 selected_steps: dict[type[BlockedStep], list] = {}
--> 215 model_logp = model.logp()
    217 for var in model.value_vars:
    218     if var not in assigned_vars:
    219         # determine if a gradient can be computed

File /opt/conda/lib/python3.12/site-packages/pymc/model/core.py:742, in Model.logp(self, vars, jacobian, sum)
    740 rv_logps: list[TensorVariable] = []
    741 if rvs:
--> 742     rv_logps = transformed_conditional_logp(
    743         rvs=rvs,
    744         rvs_to_values=self.rvs_to_values,
    745         rvs_to_transforms=self.rvs_to_transforms,
    746         jacobian=jacobian,
    747     )
    748     assert isinstance(rv_logps, list)
    750 # Replace random variables by their value variables in potential terms

File /opt/conda/lib/python3.12/site-packages/pymc/logprob/basic.py:611, in transformed_conditional_logp(rvs, rvs_to_values, rvs_to_transforms, jacobian, **kwargs)
    608     transform_rewrite = TransformValuesRewrite(values_to_transforms)  # type: ignore
    610 kwargs.setdefault("warn_rvs", False)
--> 611 temp_logp_terms = conditional_logp(
    612     rvs_to_values,
    613     extra_rewrites=transform_rewrite,
    614     use_jacobian=jacobian,
    615     **kwargs,
    616 )
    618 # The function returns the logp for every single value term we provided to it.
    619 # This includes the extra values we plugged in above, so we filter those we
    620 # actually wanted in the same order they were given in.
    621 logp_terms = {}

File /opt/conda/lib/python3.12/site-packages/pymc/logprob/basic.py:541, in conditional_logp(rv_values, warn_rvs, ir_rewriter, extra_rewrites, **kwargs)
    538 q_values = remapped_vars[: len(q_values)]
    539 q_rv_inputs = remapped_vars[len(q_values) :]
--> 541 q_logprob_vars = _logprob(
    542     node.op,
    543     q_values,
    544     *q_rv_inputs,
    545     **kwargs,
    546 )
    548 if not isinstance(q_logprob_vars, list | tuple):
    549     q_logprob_vars = [q_logprob_vars]

File /opt/conda/lib/python3.12/functools.py:909, in singledispatch.<locals>.wrapper(*args, **kw)
    905 if not args:
    906     raise TypeError(f'{funcname} requires at least '
    907                     '1 positional argument')
--> 909 return dispatch(args[0].__class__)(*args, **kw)

File /opt/conda/lib/python3.12/site-packages/pymc/logprob/scan.py:309, in logprob_ScanRV(op, values, name, *inputs, **kwargs)
    306 scan_args = ScanArgs.from_node(new_node)
    307 rv_outer_outs = get_random_outer_outputs(scan_args)
--> 309 var_indices, rv_vars, io_vars = zip(*rv_outer_outs)
    310 value_map = {_rv: _val for _rv, _val in zip(rv_vars, values)}
    312 def create_inner_out_logp(value_map: dict[TensorVariable, TensorVariable]) -> TensorVariable:

ValueError: not enough values to unpack (expected 3, got 0)

I assume that this error is caused by the fact that the first state, t, is deterministic, while both states are fed into CustomDist together. For example, if I add some random noise to t within the step function (see the commented-out line above), everything works fine. (But that's not physically meaningful: time is not random…)

I've run out of ideas on how to refactor this model so that random and deterministic state dynamics can be modeled simultaneously. During some refactoring attempts I also ran into this issue. So if anyone could point me in the right direction, I would be more than happy.

P.S.: I know that this simple model could be implemented more easily with pm.AR or the statespace models. But as mentioned, this is just the starting point for more complex models with a similar structure.

You have to model your deterministics separately from the RVs. CustomDist can’t handle deterministic (non-measurable) outputs.

In your case, this means you will have two scans in your model: the one inside CustomDist and one that computes t (although you could replace t by pt.arange, I imagine you may have more complex deterministics in mind).

Note that this does not incur a performance penalty, since during sampling separate functions are used for computing the deterministics and the logp anyway.
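In NumPy terms, the suggested split looks roughly like this: the random recursion for x runs on its own (this is all the CustomDist scan would return), and the deterministic time state is rebuilt afterwards as a plain range, as `pt.arange` would do. This is a sketch assuming the AR(1) step from the question (`x_next = mu + rho * x + noise`, `t_next = t + 1`); the helper names `joint_recursion` and `random_only_recursion`, the parameter values and the seed are all illustrative:

```python
import numpy as np

def joint_recursion(t0, x0, rho, sigma, mu, n_steps, rng):
    """Mirror of the original stacked scan: state = (t, x)."""
    states = np.empty((n_steps, 2))
    t, x = t0, x0
    for i in range(n_steps):
        t = t + 1                             # deterministic state
        x = rng.normal(mu + rho * x, sigma)   # random state
        states[i] = (t, x)
    return states

def random_only_recursion(x0, rho, sigma, mu, n_steps, rng):
    """What the CustomDist scan would return: only the random sequence."""
    xs = np.empty(n_steps)
    x = x0
    for i in range(n_steps):
        x = rng.normal(mu + rho * x, sigma)
        xs[i] = x
    return xs

# t0 = time.min() - 1 = 1950 and 150 steps, as in the model above
t0, x0, rho, sigma, mu, n_steps = 1950, 0.0, 0.8, 1.0, 0.5, 150

joint = joint_recursion(t0, x0, rho, sigma, mu, n_steps, np.random.default_rng(0))
xs = random_only_recursion(x0, rho, sigma, mu, n_steps, np.random.default_rng(0))
# Deterministic state recomputed outside the random recursion
ts = t0 + 1 + np.arange(n_steps)
split = np.column_stack([ts, xs])
```

With the same seed both versions produce identical trajectories, which illustrates why nothing is lost by keeping the deterministic state out of the CustomDist output.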

Hey @ricardoV94, thanks for your fast reply! Yes, you are right, separating both states and performing two scans would work for this model.

However, in many state space models the states depend on each other, which makes such a separation very difficult in general. My previous model is perhaps too simple to illustrate this, so I quickly assembled another model with interdependent random and deterministic states:

def step(state, d, k, sigma, dt):
    v, x = state
    v_next = -pm.Normal.dist(mu=d * dt * v * pt.abs(v), sigma=sigma) - k * dt * pt.sin(x) + v
    x_next = v * dt + x  # This is causing a ValueError, as described above
    # x_next = v * dt + pm.Normal.dist(mu=x, sigma=0.001)  # Again, this would work - but it's not physical
    state_next = pt.stack([v_next, x_next])
    return state_next, pm.pytensorf.collect_default_updates(outputs=[state_next], inputs=[state, d, k, sigma, dt])

def statespace(v0, x0, d, k, sigma, dt, n_steps, *args, **kwargs):
    state0 = pt.stack([v0, x0])
    state, updates = pytensor.scan(
        fn=step,
        sequences=[],
        outputs_info=state0,
        non_sequences=[d, k, sigma, dt],
        n_steps=n_steps,
        name='statespace',
        strict=True,
    )
    return state

with pm.Model() as model:
    time = pd.RangeIndex(0, 200)
    dt = time.step
    n_steps = pm.Data('n_steps', time.size).astype('int')  # For whatever reason, I'm getting an AssertionError during posterior sampling if this is not cast to pm.Data

    model.add_coord('idx_time', time)
    model.add_coord('states', ['v', 'x'])

    k = pm.Gamma('k', mu=0.02, sigma=0.01)
    d = pm.Gamma('d', mu=0.4, sigma=0.2)
    sigma = pm.Gamma('sigma', mu=0.02, sigma=0.01)
    x0 = pm.Normal('x0', mu=0, sigma=0.5)
    v0 = pm.Normal('v0', mu=0, sigma=0.5)
    eps = 0.01

    state = pm.CustomDist('state', v0, x0, d, k, sigma, dt, n_steps, dist=statespace, dims=('idx_time', 'states'))
    v = state[:, 0]
    x = state[:, 1]

    x_obs = pm.Normal('x_obs', mu=x, sigma=eps, dims=('idx_time'))

# Generate fake data
with pm.do(model, {k: 0.02, d: 0.4, sigma: 0.02, x0: 2.4, v0: 0}) as model_generative:
    idata_generative = pm.sample_prior_predictive()
    data_x = idata_generative.prior.x_obs.sel(chain=0, draw=42).values

# Sampling
with pm.observe(model, {x_obs: data_x}) as model_observed:
    idata = pm.sample_prior_predictive()
    idata.extend(pm.sample(nuts_sampler='numpyro'))
    idata = pm.sample_posterior_predictive(idata, extend_inferencedata=True, predictions=False)

This model more or less represents a pendulum damped by turbulent drag. According to kinematics, position x and velocity v have a deterministic relationship, but the incremental drag force is random. As you can see, v_next and x_next both depend on the previous values of both states, so a simple separation into two scans would not be possible here…

I also tried to separate the states after applying scan, but so far without any success:

    state = statespace(v0, x0, d, k, sigma, dt, n_steps)
    v = state[:, 0]
    x = state[:, 1]
    model.register_rv(v, 'v', dims=('idx_time',))
    x = pm.Deterministic('x', x, dims=('idx_time',))
    ...

Your CustomDist scan can use/compute deterministics but can’t return them. So you shouldn’t stack the two, but instead return only the non-deterministic sequence. You can always recover the deterministic sequence with a second scan outside the CustomDist that takes the initial value and the non-deterministic sequence as input (which is what PyMC would need in any case).

Unfortunately, there’s no machinery to automatically derive the deterministic function from the measurable one, so you have to work it out yourself.

For instance, say you deterministically carry around the difference between consecutive steps when generating the random sequence, something like:

def dist(x0, ...):
  def step(x_tm1, diff_tm1):
    x_t = pm.Normal.dist((x_tm1 + diff_tm1) / 2)
    diff_t = x_t - x_tm1
    return (x_t, diff_t), update_t...
  
  (x_seq, diff_seq), updates = scan(...)
  return x_seq  # don't return diff_seq

with pm.Model() as m:
  x0 = ...
  x_seq = pm.CustomDist("x_seq", x0, dist=dist, ...)
  # recompute diff_t for access later
  diff_t = pm.Deterministic("diff_t", pt.diff(x_seq))

Your CustomDist can return (and sample) x sequences. If you want access to the underlying deterministic diffs, you have to recompute them, either with a deterministic scan or, as in this case, with a simpler deterministic operation.
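A NumPy mirror of this sketch, assuming `sigma=1` for the Normal and an initial carried difference `diff0 = 0` (both illustrative choices, not from the post). It shows that the carried differences can be recovered from the sampled x sequence alone, provided x0 is prepended; `np.diff(x_seq)` on its own (like `pt.diff(x_seq)`) returns one value fewer and misses the first difference:

```python
import numpy as np

def generate(x0, diff0, n_steps, rng):
    """Generative recursion carrying x and the one-step difference together."""
    x_seq = np.empty(n_steps)
    diff_seq = np.empty(n_steps)
    x_tm1, diff_tm1 = x0, diff0
    for i in range(n_steps):
        x_t = rng.normal((x_tm1 + diff_tm1) / 2, 1.0)  # random part
        diff_t = x_t - x_tm1                            # deterministic part
        x_seq[i], diff_seq[i] = x_t, diff_t
        x_tm1, diff_tm1 = x_t, diff_t
    return x_seq, diff_seq

x0, diff0 = 0.3, 0.0
x_seq, diff_seq = generate(x0, diff0, n_steps=50, rng=np.random.default_rng(1))

# Only x_seq would be returned by the CustomDist; the diffs are recomputed afterwards
recovered_all = np.diff(np.concatenate([[x0], x_seq]))  # 50 values, incl. the first diff
recovered_tail = np.diff(x_seq)                          # 49 values, first diff missing
```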

The problem with your solution is that during inverse (MCMC) sampling PyMC would need to figure out how to express x as a deterministic function of the sampled v, which is not encoded in your generative scan, since that builds x and v simultaneously. PyMC could in principle derive it (just as it derives the logp of v for inverse sampling), but there’s no easy way to achieve that right now.

You’re running into all the snags that caused me to work on the statespace module!

If the actual model you’re going for is linear and Gaussian, you can implement a custom statespace class to handle arbitrary models. Here’s the tutorial notebook on custom classes. Plus you get all the post-estimation stuff for free! I believe the pendulum example, for instance, falls neatly within the available framework.

If your model is non-linear, you can always linearize it around the steady state (assuming it's stationary; if it’s not, you can transform it to be stationary). If it’s non-Gaussian, you can often introduce latent states, e.g. to introduce autocorrelation in the innovations.

Worst case, you can open a PR for an extended or unscented Kalman filter.
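To make the linearization concrete for the pendulum example from this thread: taking the mean dynamics of its `step` (noise dropped) with state `(v, x)`, a finite-difference Jacobian at the rest point `(0, 0)` gives `[[1, -k*dt], [dt, 1]]`. The quadratic drag term `d*dt*v*|v|` has zero derivative at `v = 0`, so it drops out of the linear model entirely. This is a NumPy sketch; the parameter values follow the thread's generative settings, and the helpers `mean_step`/`jacobian` and the step size `h` are illustrative choices:

```python
import numpy as np

k, d, dt = 0.02, 0.4, 1.0

def mean_step(state):
    """Noise-free version of the pendulum step: state = (v, x)."""
    v, x = state
    v_next = v - d * dt * v * np.abs(v) - k * dt * np.sin(x)
    x_next = x + v * dt
    return np.array([v_next, x_next])

def jacobian(f, s0, h=1e-6):
    """Central finite-difference Jacobian: J[i, j] = d f_i / d s_j at s0."""
    s0 = np.asarray(s0, dtype=float)
    cols = []
    for j in range(s0.size):
        e = np.zeros_like(s0)
        e[j] = h
        cols.append((f(s0 + e) - f(s0 - e)) / (2 * h))
    return np.column_stack(cols)

# Linearization around the rest point (v, x) = (0, 0)
J = jacobian(mean_step, [0.0, 0.0])
expected = np.array([[1.0, -k * dt], [dt, 1.0]])
```

Around a moving operating point, e.g. `v = 0.5`, the `(0, 0)` entry becomes `1 - 2*d*dt*|v|` (here 0.6), which is where the drag would reappear in a locally linearized model.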

Knowing nothing about the actual problem you want to solve I can’t give more specific advice. But I recommend you give the statespace classes a look.

Thanks for your very fast feedback, @ricardoV94 & @jessegrabowski !

Knowing nothing about the actual problem you want to solve I can’t give more specific advice.

As said in my first post, my goal is to model several correlated discrete time series. These are measurements related to some underlying states, which I want to determine and possibly also extrapolate into the future. So far, I don’t know yet exactly which model I should implement for this problem. However, for me it would be easiest to model these time series as a non-linear state space model, since that fits best with my personal intuition about the problem and would therefore allow me to try out several different dynamics, successively from very simple to more complex ones.

The pendulum model, however, is entirely made up, just to have a self-contained example for this thread that shows some typical structures. But I agree that it should (almost) fit into the PyMCStateSpace framework.

If your model is non-linear, you can always linearize it around the steady state (assuming it's stationary; if it’s not, you can transform it to be stationary). If it’s non-Gaussian, you can often introduce latent states, e.g. to introduce autocorrelation in the innovations.

I actually expect the dynamics of the model to have some switching behavior (e.g. saturation or bang-bang dynamics), which cannot be sufficiently approximated by linearization. My rationale for also considering non-Gaussian noise is that the domain of some variables is bounded, e.g. to the non-negative reals only. I’m also not yet sure whether introducing a few discrete states would make sense, but for now I think it’s reasonable to treat it as a continuous problem.

Given this context, I think @ricardoV94’s suggestion is the best way forward: register only the random states with CustomDist and recover the deterministic states afterwards. It is perhaps not the most convenient structure for trying out many different dynamics, but likely the best I can get for now. I’ve adapted the pendulum model accordingly:

def step(v, x, d, k, sigma, dt):
    v_next = -pm.Normal.dist(mu=d * dt * v * pt.abs(v), sigma=sigma) - k * dt * pt.sin(x) + v
    x_next = v * dt + x
    return [v_next, x_next], pm.pytensorf.collect_default_updates(outputs=[v_next, x_next], inputs=[v, x, d, k, sigma, dt])

def step_d(v, x, d, k, sigma, dt):
    x_next = v * dt + x
    return [x_next], pm.pytensorf.collect_default_updates(outputs=[x_next], inputs=[v, x, d, k, sigma, dt])

def statespace(v0, x0, d, k, sigma, dt, n_steps, *args, **kwargs):
    [v, x], updates = pytensor.scan(
        fn=step,
        sequences=[],
        outputs_info=[v0, x0],
        non_sequences=[d, k, sigma, dt],
        n_steps=n_steps,
        name='statespace',
        strict=True,
    )
    return v

def statespace_d(v, x0, d, k, sigma, dt, n_steps, *args, **kwargs):
    x, updates = pytensor.scan(
        fn=step_d,
        sequences=[v],
        outputs_info=[x0],
        non_sequences=[d, k, sigma, dt],
        n_steps=n_steps,
        name='statespace',
        strict=True,
    )
    return x

with pm.Model() as model:
    time = pd.RangeIndex(0, 200)
    dt = time.step
    n_steps = pm.Data('n_steps', time.size).astype('int')  # For whatever reason, I'm getting an AssertionError during posterior sampling if this is not cast to pm.Data

    model.add_coord('idx_time', time)
    model.add_coord('states', ['v', 'x'])

    k = pm.Gamma('k', mu=0.02, sigma=0.01)
    d = pm.Gamma('d', mu=0.4, sigma=0.2)
    sigma = pm.Gamma('sigma', mu=0.02, sigma=0.01)
    x0 = pm.Normal('x0', mu=0, sigma=0.5)
    v0 = pm.Normal('v0', mu=0, sigma=0.5)
    eps = 0.01

    v = pm.CustomDist('v', v0, x0, d, k, sigma, dt, n_steps, dist=statespace, dims=('idx_time',))
    x = pm.Deterministic('x', statespace_d(v, x0, d, k, sigma, dt, n_steps), dims=('idx_time',))
    state = pm.Deterministic('state', pt.stack([v,x]).T, dims=('idx_time', 'states'))  # Just to keep same variable name for postprocessing

    x_obs = pm.Normal('x_obs', mu=x, sigma=eps, dims=('idx_time'))

with pm.do(model, {k: 0.02, d: 0.4, sigma: 0.02, x0: 2.4, v0: 0}) as model_generative:
    idata_generative = pm.sample_prior_predictive()
    data_x = idata_generative.prior.x_obs.sel(chain=0, draw=42).values

with pm.observe(model, {x_obs: data_x}) as model_observed:
    idata = pm.sample_prior_predictive()
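Separately from sampling, the alignment in `statespace_d` may be worth double-checking. In the generative `step`, `x_next = v * dt + x` uses the velocity from the *previous* step, so `x_1 = x0 + v0 * dt`. Passing the scan output `v` (which starts at `v_1`) as the sequence therefore shifts the velocities by one step. The NumPy sketch below mirrors both recursions (arbitrary seed, parameters from the thread's generative settings; `replay_x` is a hypothetical helper). It suggests that the replay should be fed `[v0, v_1, ..., v_{N-1}]`, e.g. something like `pt.concatenate([v0[None], v[:-1]])` in PyTensor (an untested suggestion):

```python
import numpy as np

k, d, sigma, dt = 0.02, 0.4, 0.02, 1.0
x0, v0, n_steps = 2.4, 0.0, 200
rng = np.random.default_rng(42)

# Generative recursion mirroring `step`: both states are updated from the previous (v, x)
v_seq, x_seq = np.empty(n_steps), np.empty(n_steps)
v, x = v0, x0
for i in range(n_steps):
    v_next = -rng.normal(d * dt * v * abs(v), sigma) - k * dt * np.sin(x) + v
    x_next = v * dt + x  # uses the *previous* v
    v, x = v_next, x_next
    v_seq[i], x_seq[i] = v, x

def replay_x(x0, v_inputs, dt):
    """Deterministic replay, like `statespace_d`: x_next = v * dt + x."""
    out, x = np.empty(len(v_inputs)), x0
    for i, v in enumerate(v_inputs):
        x = v * dt + x
        out[i] = x
    return out

v_prev = np.concatenate([[v0], v_seq[:-1]])  # v0, v_1, ..., v_{N-1}
x_shifted = replay_x(x0, v_prev, dt)
x_unshifted = replay_x(x0, v_seq, dt)        # v_1, ..., v_N (alignment used above)
# Vectorized equivalent of the shifted replay
x_cumsum = x0 + dt * np.cumsum(v_prev)
```

Because this replay is a simple cumulative sum, a `pt.cumsum` might even replace the second scan in this particular model.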

However, when I sample from this model with numpyro, I get the following error after all samples have been drawn:

---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
Cell In[30], line 2
      1 with model_observed:
----> 2     idata.extend(pm.sample(nuts_sampler='numpyro'))
      3     # idata.extend(pm.sample(nuts_sampler='blackjax'))
      4     # idata.extend(pm.sample(nuts_sampler='nutpie'))
      5     # idata.extend(pm.sample(nuts_sampler='pymc'))
      6 idata

File /opt/conda/lib/python3.12/site-packages/pymc/sampling/mcmc.py:725, in sample(draws, tune, chains, cores, random_seed, progressbar, progressbar_theme, step, var_names, nuts_sampler, initvals, init, jitter_max_retries, n_init, trace, discard_tuned_samples, compute_convergence_checks, keep_warning_stat, return_inferencedata, idata_kwargs, nuts_sampler_kwargs, callback, mp_ctx, blas_cores, model, **kwargs)
    720         raise ValueError(
    721             "Model can not be sampled with NUTS alone. Your model is probably not continuous."
    722         )
    724     with joined_blas_limiter():
--> 725         return _sample_external_nuts(
    726             sampler=nuts_sampler,
    727             draws=draws,
    728             tune=tune,
    729             chains=chains,
    730             target_accept=kwargs.pop("nuts", {}).get("target_accept", 0.8),
    731             random_seed=random_seed,
    732             initvals=initvals,
    733             model=model,
    734             var_names=var_names,
    735             progressbar=progressbar,
    736             idata_kwargs=idata_kwargs,
    737             compute_convergence_checks=compute_convergence_checks,
    738             nuts_sampler_kwargs=nuts_sampler_kwargs,
    739             **kwargs,
    740         )
    742 if isinstance(step, list):
    743     step = CompoundStep(step)

File /opt/conda/lib/python3.12/site-packages/pymc/sampling/mcmc.py:356, in _sample_external_nuts(sampler, draws, tune, chains, target_accept, random_seed, initvals, model, var_names, progressbar, idata_kwargs, compute_convergence_checks, nuts_sampler_kwargs, **kwargs)
    353 elif sampler in ("numpyro", "blackjax"):
    354     import pymc.sampling.jax as pymc_jax
--> 356     idata = pymc_jax.sample_jax_nuts(
    357         draws=draws,
    358         tune=tune,
    359         chains=chains,
    360         target_accept=target_accept,
    361         random_seed=random_seed,
    362         initvals=initvals,
    363         model=model,
    364         var_names=var_names,
    365         progressbar=progressbar,
    366         nuts_sampler=sampler,
    367         idata_kwargs=idata_kwargs,
    368         compute_convergence_checks=compute_convergence_checks,
    369         **nuts_sampler_kwargs,
    370     )
    371     return idata
    373 else:

File /opt/conda/lib/python3.12/site-packages/pymc/sampling/jax.py:655, in sample_jax_nuts(draws, tune, chains, target_accept, random_seed, initvals, jitter, model, var_names, nuts_kwargs, progressbar, keep_untransformed, chain_method, postprocessing_backend, postprocessing_vectorize, postprocessing_chunks, idata_kwargs, compute_convergence_checks, nuts_sampler)
    652     log_likelihood = None
    654 jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=vars_to_sample)
--> 655 result = _postprocess_samples(
    656     jax_fn,
    657     raw_mcmc_samples,
    658     postprocessing_backend=postprocessing_backend,
    659     postprocessing_vectorize=postprocessing_vectorize,
    660     donate_samples=True,
    661 )
    662 del raw_mcmc_samples
    663 mcmc_samples = {v.name: r for v, r in zip(vars_to_sample, result)}

File /opt/conda/lib/python3.12/site-packages/pymc/sampling/jax.py:205, in _postprocess_samples(jax_fn, raw_mcmc_samples, postprocessing_backend, postprocessing_vectorize, donate_samples)
    202     def process_fn(x):
    203         return jax.vmap(jax.vmap(jax_fn))(*_device_put(x, postprocessing_backend))
--> 205     return jax.jit(process_fn, donate_argnums=0 if donate_samples else None)(raw_mcmc_samples)
    207 else:
    208     raise ValueError(f"Unrecognized postprocessing_vectorize: {postprocessing_vectorize}")

    [... skipping hidden 11 frame]

File /opt/conda/lib/python3.12/site-packages/pymc/sampling/jax.py:203, in _postprocess_samples.<locals>.process_fn(x)
    202 def process_fn(x):
--> 203     return jax.vmap(jax.vmap(jax_fn))(*_device_put(x, postprocessing_backend))

    [... skipping hidden 6 frame]

File /tmp/tmppi3fv53j:33, in jax_funcified_fgraph(k_log_, d_log_, sigma_log_, x0, v0, v)
     31 scalar_variable = scalar_from_tensor(tensor_variable_10)
     32 # Subtensor{start:stop:step}(v, 0, ScalarFromTensor.0, 1)
---> 33 tensor_variable_11 = subtensor(v, scalar_constant_1, scalar_variable, scalar_constant_2)
     34 # Scan{statespace, while_loop=False, inplace=none}(200, Subtensor{start:stop:step}.0, SetSubtensor{:stop}.0)
     35 tensor_variable_12 = scan(tensor_constant_1, tensor_variable_11, tensor_variable_4)

File /opt/conda/lib/python3.12/site-packages/pytensor/link/jax/dispatch/subtensor.py:45, in jax_funcify_Subtensor.<locals>.subtensor(x, *ilists)
     42 if len(indices) == 1:
     43     indices = indices[0]
---> 45 return x.__getitem__(indices)

File /opt/conda/lib/python3.12/site-packages/jax/_src/numpy/array_methods.py:739, in _forward_operator_to_aval.<locals>.op(self, *args)
    738 def op(self, *args):
--> 739   return getattr(self.aval, f"_{name}")(self, *args)

File /opt/conda/lib/python3.12/site-packages/jax/_src/numpy/array_methods.py:352, in _getitem(self, item)
    351 def _getitem(self, item):
--> 352   return lax_numpy._rewriting_take(self, item)

File /opt/conda/lib/python3.12/site-packages/jax/_src/numpy/lax_numpy.py:6594, in _rewriting_take(arr, idx, indices_are_sorted, unique_indices, mode, fill_value)
   6591       return lax.dynamic_index_in_dim(arr, idx, keepdims=False)
   6593 treedef, static_idx, dynamic_idx = _split_index_for_jit(idx, arr.shape)
-> 6594 return _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted,
   6595                unique_indices, mode, fill_value)

File /opt/conda/lib/python3.12/site-packages/jax/_src/numpy/lax_numpy.py:6603, in _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted, unique_indices, mode, fill_value)
   6600 def _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted,
   6601             unique_indices, mode, fill_value):
   6602   idx = _merge_static_and_dynamic_indices(treedef, static_idx, dynamic_idx)
-> 6603   indexer = _index_to_gather(shape(arr), idx)  # shared with _scatter_update
   6604   y = arr
   6606   if fill_value is not None:

File /opt/conda/lib/python3.12/site-packages/jax/_src/numpy/lax_numpy.py:6846, in _index_to_gather(x_shape, idx, normalize_indices)
   6837 if not all(_is_slice_element_none_or_constant_or_symbolic(elt)
   6838            for elt in (i.start, i.stop, i.step)):
   6839   msg = ("Array slice indices must have static start/stop/step to be used "
   6840          "with NumPy indexing syntax. "
   6841          f"Found slice({i.start}, {i.stop}, {i.step}). "
   (...)
   6844          "dynamic_update_slice (JAX does not support dynamically sized "
   6845          "arrays within JIT compiled functions).")
-> 6846   raise IndexError(msg)
   6848 start, step, slice_size = _preprocess_slice(i, x_shape[x_axis])
   6849 slice_shape.append(slice_size)

IndexError: Array slice indices must have static start/stop/step to be used with NumPy indexing syntax. Found slice(0, Traced<ShapedArray(int64[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, 1). To index a statically sized array at a dynamic position, try lax.dynamic_slice/dynamic_update_slice (JAX does not support dynamically sized arrays within JIT compiled functions).

With blackjax and nutpie I get similar error messages before sampling has started. Unfortunately, PyMC's own NUTS sampler is too slow to give me results in a reasonable time.

---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
Cell In[34], line 3
      1 with model_observed:
      2     # idata.extend(pm.sample(nuts_sampler='numpyro'))
----> 3     idata.extend(pm.sample(nuts_sampler='blackjax'))
      4     # idata.extend(pm.sample(nuts_sampler='nutpie'))
      5     # idata.extend(pm.sample(nuts_sampler='pymc'))
      6 idata

    [... same frames as in the numpyro traceback above ...]

IndexError: Array slice indices must have static start/stop/step to be used with NumPy indexing syntax. Found slice(0, Traced<ShapedArray(int64[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, 1). To index a statically sized array at a dynamic position, try lax.dynamic_slice/dynamic_update_slice (JAX does not support dynamically sized arrays within JIT compiled functions).

Any idea how this could be fixed? Right now, I suspect a bug in the samplers. I have the following versions installed:

pymc                          5.16.2
pytensor                      2.25.2
arviz                         0.19.0
numpyro                       0.15.1
blackjax                      1.2.2
nutpie                        0.13.2
jax                           0.4.30
numpy                         1.26.4

By the way, is there a way to programmatically check whether a PyTensor variable is random or deterministic, something like assert x.is_deterministic() or assert v.is_random()?

You’re running into all the snags that caused me to work on the statespace module!

Yeah, while looking through several other threads and notebooks, I already figured out your motivation for this… :wink:

JAX is very finicky about shapes, so you can try to use freeze_dims_and_data before attempting to sample with numpyro:

https://www.pymc.io/projects/docs/en/v5.13.1/api/model/generated/pymc.model.transform.optimization.freeze_dims_and_data.html

For nutpie that’s more surprising, since it uses the Numba backend. Could you post the traceback?

Thanks for the hint about freeze_dims_and_data! Unfortunately, with numpyro or blackjax I’m still getting the same error messages as above.

However, nutpie now runs when I use freeze_dims_and_data! :+1: The results don't look great yet, but this could also be caused by the model structure, which I may be able to improve further.

Without freeze_dims_and_data I get the following error:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[33], line 4
      1 with model_observed:
      2     # idata.extend(pm.sample(nuts_sampler='numpyro'))
      3     # idata.extend(pm.sample(nuts_sampler='blackjax'))
----> 4     idata.extend(pm.sample(nuts_sampler='nutpie'))
      5     # idata.extend(pm.sample(nuts_sampler='pymc'))
      6 idata

File /opt/conda/lib/python3.12/site-packages/pymc/sampling/mcmc.py:725, in sample(draws, tune, chains, cores, random_seed, progressbar, progressbar_theme, step, var_names, nuts_sampler, initvals, init, jitter_max_retries, n_init, trace, discard_tuned_samples, compute_convergence_checks, keep_warning_stat, return_inferencedata, idata_kwargs, nuts_sampler_kwargs, callback, mp_ctx, blas_cores, model, **kwargs)
    720         raise ValueError(
    721             "Model can not be sampled with NUTS alone. Your model is probably not continuous."
    722         )
    724     with joined_blas_limiter():
--> 725         return _sample_external_nuts(
    726             sampler=nuts_sampler,
    727             draws=draws,
    728             tune=tune,
    729             chains=chains,
    730             target_accept=kwargs.pop("nuts", {}).get("target_accept", 0.8),
    731             random_seed=random_seed,
    732             initvals=initvals,
    733             model=model,
    734             var_names=var_names,
    735             progressbar=progressbar,
    736             idata_kwargs=idata_kwargs,
    737             compute_convergence_checks=compute_convergence_checks,
    738             nuts_sampler_kwargs=nuts_sampler_kwargs,
    739             **kwargs,
    740         )
    742 if isinstance(step, list):
    743     step = CompoundStep(step)

File /opt/conda/lib/python3.12/site-packages/pymc/sampling/mcmc.py:307, in _sample_external_nuts(sampler, draws, tune, chains, target_accept, random_seed, initvals, model, var_names, progressbar, idata_kwargs, compute_convergence_checks, nuts_sampler_kwargs, **kwargs)
    302 if var_names is not None:
    303     warnings.warn(
    304         "`var_names` are currently ignored by the nutpie sampler",
    305         UserWarning,
    306     )
--> 307 compiled_model = nutpie.compile_pymc_model(model)
    308 t_start = time.time()
    309 idata = nutpie.sample(
    310     compiled_model,
    311     draws=draws,
   (...)
    317     **nuts_sampler_kwargs,
    318 )

File /opt/conda/lib/python3.12/site-packages/nutpie/compile_pymc.py:391, in compile_pymc_model(model, backend, gradient_backend, **kwargs)
    388     backend = "numba"
    390 if backend.lower() == "numba":
--> 391     return _compile_pymc_model_numba(model, **kwargs)
    392 elif backend.lower() == "jax":
    393     return _compile_pymc_model_jax(
    394         model, gradient_backend=gradient_backend, **kwargs
    395     )

File /opt/conda/lib/python3.12/site-packages/nutpie/compile_pymc.py:207, in _compile_pymc_model_numba(model, **kwargs)
    200 with warnings.catch_warnings():
    201     warnings.filterwarnings(
    202         "ignore",
    203         message="Cannot cache compiled function .* as it uses dynamic globals",
    204         category=numba.NumbaWarning,
    205     )
--> 207     logp_numba = numba.cfunc(c_sig, **kwargs)(logp_numba_raw)
    209 expand_shared_names = [var.name for var in expand_fn_pt.get_shared()]
    210 expand_numba_raw, c_sig_expand = _make_c_expand_func(
    211     n_dim, n_expanded, expand_fn, user_data, expand_shared_names, shared_data
    212 )

File /opt/conda/lib/python3.12/site-packages/numba/core/decorators.py:275, in cfunc.<locals>.wrapper(func)
    273 if cache:
    274     res.enable_caching()
--> 275 res.compile()
    276 return res

File /opt/conda/lib/python3.12/site-packages/numba/core/compiler_lock.py:35, in _CompilerLock.__call__.<locals>._acquire_compile_lock(*args, **kwargs)
     32 @functools.wraps(func)
     33 def _acquire_compile_lock(*args, **kwargs):
     34     with self:
---> 35         return func(*args, **kwargs)

File /opt/conda/lib/python3.12/site-packages/numba/core/ccallback.py:68, in CFunc.compile(self)
     65 cres = self._cache.load_overload(self._sig,
     66                                  self._targetdescr.target_context)
     67 if cres is None:
---> 68     cres = self._compile_uncached()
     69     self._cache.save_overload(self._sig, cres)
     70 else:

File /opt/conda/lib/python3.12/site-packages/numba/core/ccallback.py:82, in CFunc._compile_uncached(self)
     79 sig = self._sig
     81 # Compile native function as well as cfunc wrapper
---> 82 return self._compiler.compile(sig.args, sig.return_type)

File /opt/conda/lib/python3.12/site-packages/numba/core/dispatcher.py:80, in _FunctionCompiler.compile(self, args, return_type)
     79 def compile(self, args, return_type):
---> 80     status, retval = self._compile_cached(args, return_type)
     81     if status:
     82         return retval

File /opt/conda/lib/python3.12/site-packages/numba/core/dispatcher.py:94, in _FunctionCompiler._compile_cached(self, args, return_type)
     91     pass
     93 try:
---> 94     retval = self._compile_core(args, return_type)
     95 except errors.TypingError as e:
     96     self._failed_cache[key] = e

File /opt/conda/lib/python3.12/site-packages/numba/core/dispatcher.py:107, in _FunctionCompiler._compile_core(self, args, return_type)
    104 flags = self._customize_flags(flags)
    106 impl = self._get_implementation(args, {})
--> 107 cres = compiler.compile_extra(self.targetdescr.typing_context,
    108                               self.targetdescr.target_context,
    109                               impl,
    110                               args=args, return_type=return_type,
    111                               flags=flags, locals=self.locals,
    112                               pipeline_class=self.pipeline_class)
    113 # Check typing error if object mode is used
    114 if cres.typing_error is not None and not flags.enable_pyobject:

File /opt/conda/lib/python3.12/site-packages/numba/core/compiler.py:744, in compile_extra(typingctx, targetctx, func, args, return_type, flags, locals, library, pipeline_class)
    720 """Compiler entry point
    721 
    722 Parameter
   (...)
    740     compiler pipeline
    741 """
    742 pipeline = pipeline_class(typingctx, targetctx, library,
    743                           args, return_type, flags, locals)
--> 744 return pipeline.compile_extra(func)

File /opt/conda/lib/python3.12/site-packages/numba/core/compiler.py:438, in CompilerBase.compile_extra(self, func)
    436 self.state.lifted = ()
    437 self.state.lifted_from = None
--> 438 return self._compile_bytecode()

File /opt/conda/lib/python3.12/site-packages/numba/core/compiler.py:506, in CompilerBase._compile_bytecode(self)
    502 """
    503 Populate and run pipeline for bytecode input
    504 """
    505 assert self.state.func_ir is None
--> 506 return self._compile_core()

File /opt/conda/lib/python3.12/site-packages/numba/core/compiler.py:481, in CompilerBase._compile_core(self)
    478 except Exception as e:
    479     if (utils.use_new_style_errors() and not
    480             isinstance(e, errors.NumbaError)):
--> 481         raise e
    483     self.state.status.fail_reason = e
    484     if is_final_pipeline:

File /opt/conda/lib/python3.12/site-packages/numba/core/compiler.py:472, in CompilerBase._compile_core(self)
    470 res = None
    471 try:
--> 472     pm.run(self.state)
    473     if self.state.cr is not None:
    474         break

File /opt/conda/lib/python3.12/site-packages/numba/core/compiler_machinery.py:364, in PassManager.run(self, state)
    361 except Exception as e:
    362     if (utils.use_new_style_errors() and not
    363             isinstance(e, errors.NumbaError)):
--> 364         raise e
    365     msg = "Failed in %s mode pipeline (step: %s)" % \
    366         (self.pipeline_name, pass_desc)
    367     patched_exception = self._patch_error(msg, e)

File /opt/conda/lib/python3.12/site-packages/numba/core/compiler_machinery.py:356, in PassManager.run(self, state)
    354 pass_inst = _pass_registry.get(pss).pass_inst
    355 if isinstance(pass_inst, CompilerPass):
--> 356     self._runPass(idx, pass_inst, state)
    357 else:
    358     raise BaseException("Legacy pass in use")

File /opt/conda/lib/python3.12/site-packages/numba/core/compiler_lock.py:35, in _CompilerLock.__call__.<locals>._acquire_compile_lock(*args, **kwargs)
     32 @functools.wraps(func)
     33 def _acquire_compile_lock(*args, **kwargs):
     34     with self:
---> 35         return func(*args, **kwargs)

File /opt/conda/lib/python3.12/site-packages/numba/core/compiler_machinery.py:311, in PassManager._runPass(self, index, pss, internal_state)
    309     mutated |= check(pss.run_initialization, internal_state)
    310 with SimpleTimer() as pass_time:
--> 311     mutated |= check(pss.run_pass, internal_state)
    312 with SimpleTimer() as finalize_time:
    313     mutated |= check(pss.run_finalizer, internal_state)

File /opt/conda/lib/python3.12/site-packages/numba/core/compiler_machinery.py:273, in PassManager._runPass.<locals>.check(func, compiler_state)
    272 def check(func, compiler_state):
--> 273     mangled = func(compiler_state)
    274     if mangled not in (True, False):
    275         msg = ("CompilerPass implementations should return True/False. "
    276                "CompilerPass with name '%s' did not.")

File /opt/conda/lib/python3.12/site-packages/numba/core/untyped_passes.py:1731, in LiteralUnroll.run_pass(self, state)
   1729 pm.add_pass(RewriteSemanticConstants, "rewrite semantic constants")
   1730 pm.finalize()
-> 1731 pm.run(state)
   1732 return True

File /opt/conda/lib/python3.12/site-packages/numba/core/compiler_machinery.py:364, in PassManager.run(self, state)
    361 except Exception as e:
    362     if (utils.use_new_style_errors() and not
    363             isinstance(e, errors.NumbaError)):
--> 364         raise e
    365     msg = "Failed in %s mode pipeline (step: %s)" % \
    366         (self.pipeline_name, pass_desc)
    367     patched_exception = self._patch_error(msg, e)

File /opt/conda/lib/python3.12/site-packages/numba/core/compiler_machinery.py:356, in PassManager.run(self, state)
    354 pass_inst = _pass_registry.get(pss).pass_inst
    355 if isinstance(pass_inst, CompilerPass):
--> 356     self._runPass(idx, pass_inst, state)
    357 else:
    358     raise BaseException("Legacy pass in use")

File /opt/conda/lib/python3.12/site-packages/numba/core/compiler_lock.py:35, in _CompilerLock.__call__.<locals>._acquire_compile_lock(*args, **kwargs)
     32 @functools.wraps(func)
     33 def _acquire_compile_lock(*args, **kwargs):
     34     with self:
---> 35         return func(*args, **kwargs)

File /opt/conda/lib/python3.12/site-packages/numba/core/compiler_machinery.py:311, in PassManager._runPass(self, index, pss, internal_state)
    309     mutated |= check(pss.run_initialization, internal_state)
    310 with SimpleTimer() as pass_time:
--> 311     mutated |= check(pss.run_pass, internal_state)
    312 with SimpleTimer() as finalize_time:
    313     mutated |= check(pss.run_finalizer, internal_state)

File /opt/conda/lib/python3.12/site-packages/numba/core/compiler_machinery.py:273, in PassManager._runPass.<locals>.check(func, compiler_state)
    272 def check(func, compiler_state):
--> 273     mangled = func(compiler_state)
    274     if mangled not in (True, False):
    275         msg = ("CompilerPass implementations should return True/False. "
    276                "CompilerPass with name '%s' did not.")

File /opt/conda/lib/python3.12/site-packages/numba/core/typed_passes.py:112, in BaseTypeInference.run_pass(self, state)
    106 """
    107 Type inference and legalization
    108 """
    109 with fallback_context(state, 'Function "%s" failed type inference'
    110                       % (state.func_id.func_name,)):
    111     # Type inference
--> 112     typemap, return_type, calltypes, errs = type_inference_stage(
    113         state.typingctx,
    114         state.targetctx,
    115         state.func_ir,
    116         state.args,
    117         state.return_type,
    118         state.locals,
    119         raise_errors=self._raise_errors)
    120     state.typemap = typemap
    121     # save errors in case of partial typing

File /opt/conda/lib/python3.12/site-packages/numba/core/typed_passes.py:93, in type_inference_stage(typingctx, targetctx, interp, args, return_type, locals, raise_errors)
     91     infer.build_constraint()
     92     # return errors in case of partial typing
---> 93     errs = infer.propagate(raise_errors=raise_errors)
     94     typemap, restype, calltypes = infer.unify(raise_errors=raise_errors)
     96 return _TypingResults(typemap, restype, calltypes, errs)

File /opt/conda/lib/python3.12/site-packages/numba/core/typeinfer.py:1083, in TypeInferer.propagate(self, raise_errors)
   1080 oldtoken = newtoken
   1081 # Errors can appear when the type set is incomplete; only
   1082 # raise them when there is no progress anymore.
-> 1083 errors = self.constraints.propagate(self)
   1084 newtoken = self.get_state_token()
   1085 self.debug.propagate_finished()

File /opt/conda/lib/python3.12/site-packages/numba/core/typeinfer.py:182, in ConstraintNetwork.propagate(self, typeinfer)
    180     errors.append(utils.chain_exception(new_exc, e))
    181 elif utils.use_new_style_errors():
--> 182     raise e
    183 else:
    184     msg = ("Unknown CAPTURED_ERRORS style: "
    185            f"'{config.CAPTURED_ERRORS}'.")

File /opt/conda/lib/python3.12/site-packages/numba/core/typeinfer.py:160, in ConstraintNetwork.propagate(self, typeinfer)
    157 with typeinfer.warnings.catch_warnings(filename=loc.filename,
    158                                        lineno=loc.line):
    159     try:
--> 160         constraint(typeinfer)
    161     except ForceLiteralArg as e:
    162         errors.append(e)

File /opt/conda/lib/python3.12/site-packages/numba/core/typeinfer.py:583, in CallConstraint.__call__(self, typeinfer)
    581     fnty = typevars[self.func].getone()
    582 with new_error_context("resolving callee type: {0}", fnty):
--> 583     self.resolve(typeinfer, typevars, fnty)

File /opt/conda/lib/python3.12/site-packages/numba/core/typeinfer.py:606, in CallConstraint.resolve(self, typeinfer, typevars, fnty)
    604     fnty = fnty.instance_type
    605 try:
--> 606     sig = typeinfer.resolve_call(fnty, pos_args, kw_args)
...

AttributeError: module 'numpy' has no attribute 'bool'.
`np.bool` was a deprecated alias for the builtin `bool`. To avoid this error in existing code, use `bool` by itself. Doing this will not modify any behavior and is safe. If you specifically wanted the numpy scalar type, use `np.bool_` here.
The aliases was originally deprecated in NumPy 1.20; for more details and guidance see the original release note at:
    https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations

Interesting, can you provide a minimal reproducible example? Sounds like something we should be able to patch.

Sure, it is basically the example above plus some imports:

import pandas as pd
import pymc as pm
import pytensor.tensor as pt
import pytensor

def step(v, x, d, k, sigma, dt):
    v_next = -pm.Normal.dist(mu=d * dt * v * pt.abs(v), sigma=sigma) - k * dt * pt.sin(x) + v
    x_next = v * dt + x
    return [v_next, x_next], pm.pytensorf.collect_default_updates(outputs=[v_next, x_next], inputs=[v, x, d, k, sigma, dt])

def step_d(v, x, d, k, sigma, dt):
    x_next = v * dt + x
    return [x_next], pm.pytensorf.collect_default_updates(outputs=[x_next], inputs=[v, x, d, k, sigma, dt])

def statespace(v0, x0, d, k, sigma, dt, n_steps, *args, **kwargs):
    [v, x], updates = pytensor.scan(
        fn=step,
        sequences=[],
        outputs_info=[v0, x0],
        non_sequences=[d, k, sigma, dt],
        n_steps=n_steps,
        name='statespace',
        strict=True,
    )
    return v

def statespace_d(v, x0, d, k, sigma, dt, n_steps, *args, **kwargs):
    x, updates = pytensor.scan(
        fn=step_d,
        sequences=[v],
        outputs_info=[x0],
        non_sequences=[d, k, sigma, dt],
        n_steps=n_steps,
        name='statespace',
        strict=True,
    )
    return x

with pm.Model() as model:
    time = pd.RangeIndex(0, 200)
    dt = time.step
    n_steps = pm.Data('n_steps', time.size).astype('int')  # For whatever reason, I'm getting an AssertionError during posterior sampling if this is not casted into pm.Data

    model.add_coord('idx_time', time)
    model.add_coord('states', ['v', 'x'])

    k = pm.Gamma('k', mu=0.02, sigma=0.01)
    d = pm.Gamma('d', mu=0.4, sigma=0.2)
    sigma = pm.Gamma('sigma', mu=0.02, sigma=0.01)
    x0 = pm.Normal('x0', mu=0, sigma=0.5)
    v0 = pm.Normal('v0', mu=0, sigma=0.5)
    eps = 0.01

    v = pm.CustomDist('v', v0, x0, d, k, sigma, dt, n_steps, dist=statespace, dims=('idx_time',))
    x = pm.Deterministic('x', statespace_d(v, x0, d, k, sigma, dt, n_steps), dims=('idx_time',))
    state = pm.Deterministic('state', pt.stack([v,x]).T, dims=('idx_time', 'states'))  # Just to keep same variable name for postprocessing

    x_obs = pm.Normal('x_obs', mu=x, sigma=eps, dims=('idx_time'))

with pm.do(model, {k: 0.02, d: 0.4, sigma: 0.02, x0: 2.4, v0: 0}) as model_generative:
    idata_generative = pm.sample_prior_predictive()
    data_x = idata_generative.prior.x_obs.sel(chain=0, draw=42).values

with pm.observe(model, {x_obs: data_x}) as model_observed:
    idata = pm.sample(nuts_sampler='nutpie')

So, last week I finally had some time to continue working on this topic, and I was able to solve the issues mentioned above! :slight_smile:

The root problem was that the n_steps passed to scan must be a Python integer to ensure a fixed tensor size, while CustomDist turns every dist_param into a tensor variable. Also, pt.specify_shape helped to explicitly fix the size of the tensor returned by scan, so all samplers can now work with the model. :+1:

As mentioned before, my motivation is to try out several state space models with different dynamics to fit my data. To simplify this process, I have now implemented a small StateSpace utility function that helps set up the state space models:

import pymc as pm
import pytensor.tensor as pt
import pytensor
from pymc.pytensorf import collect_default_updates


def is_rnd(f):
    """Decorator to mark a step function as a random state."""
    f._is_rnd = True
    return f

def is_det(f):
    """Decorator to mark a step function as a deterministic state."""
    f._is_rnd = False
    return f

def StateSpace(states_name=[], inputs=[], states_init=[], params=[], states_step=[], n_steps=0, **kwargs):
    """Create discrete state space model from given step functions.
    
    :param states_name: Names of state variables.
    :type states_name: list[str]
    :param inputs: Time-dependent inputs.
    :type inputs: list[pt.TensorVariable]
    :param states_init: Initial value for each state.
    :type states_init: list[pt.TensorVariable]
    :param params: Time-independent parameters.
    :type params: list[pt.TensorVariable]
    :param states_step: Step function for each state, returning the value for the current timestep given the state values for the previous timestep. Step functions must have the signature: `step(inputs, states, params) -> state_new`.
    :type states_step: list[Callable]
    :param n_steps: Number of timesteps to propagate the step functions.
    :type n_steps: int
    :return: Time evolution of each state variable.
    :rtype: list[pt.TensorVariable]
    """
    assert len(states_name) == len(states_init) == len(states_step) > 0, 'states_name, states_init, and states_step must be of the same non-zero length.'
    assert isinstance(n_steps, int), 'n_steps must be integer to ensure fixed tensor size at compile time.'
    
    n_inputs = len(inputs)
    n_states = len(states_init)
    n_params = len(params)
    idx_rnd = [i for i, step in enumerate(states_step) if getattr(step, '_is_rnd', False)]
    idx_det = [i for i, step in enumerate(states_step) if not getattr(step, '_is_rnd', False)]
    idx_inverse = [(idx_rnd + idx_det).index(i) for i in range(n_states)]
    n_states_rnd = len(idx_rnd)
    n_states_det = len(idx_det)
    states = n_states * [None]
    
    def fn_rnd(*args):
        inputs = args[:n_inputs]
        states_rnd = args[n_inputs]
        states_rnd = [states_rnd[i] for i in range(n_states_rnd)]
        states_det = args[n_inputs+1:n_inputs+1+n_states_det]
        states = [*states_rnd, *states_det]
        states = [states[j] for j in idx_inverse]  # Restore original order of states
        params = args[n_inputs+1+n_states_det:]
        out_raw = [step(inputs, states, params) for step in states_step]  # Eval step functions
        out = [pt.stack([out_raw[i] for i in idx_rnd], axis=-1)] + [out_raw[i] for i in idx_det]  # Random states first & concatenated, deterministic states last
        return out, pm.pytensorf.collect_default_updates(outputs=out, inputs=args)
    
    def dist_ss_rnd(*args, **kwargs):
        inputs = args[:n_inputs]
        states_init = args[n_inputs:n_inputs+n_states]
        params = args[n_inputs+n_states:n_inputs+n_states+n_params]
        
        out, updates = pytensor.scan(
            fn=fn_rnd,
            sequences=inputs,
            outputs_info=[pt.stack([states_init[i] for i in idx_rnd])] + [states_init[i] for i in idx_det],  # Random states first & concatenated, deterministic states last
            non_sequences=params,
            n_steps=n_steps,
            strict=True,
            return_list=True,
        )
        return out[0]
            
    states_rnd = pm.CustomDist('StateSpace_' + '_'.join([states_name[j] for j in idx_rnd]), *inputs, *states_init, *params, n_steps, dist=dist_ss_rnd) # TODO: dims
    states_rnd = pt.specify_shape(states_rnd, (n_steps, n_states_rnd))
    for i, j in enumerate(idx_rnd):
        states[j] = pm.Deterministic(states_name[j], states_rnd[:, i], **kwargs)
    states_rnd = [s for s in states if s is not None]
            
    def fn_det(*args):
        inputs = args[:n_inputs]
        states = args[n_inputs:n_inputs+n_states]
        states = [states[j] for j in idx_inverse]  # Restore original order of states
        params = args[n_inputs+n_states:]
        out = [step(inputs, states, params) for step in states_step if not getattr(step, '_is_rnd', False)]
        return out, pm.pytensorf.collect_default_updates(outputs=out, inputs=args)
    
    out, updates = pytensor.scan(
        fn=fn_det,
        sequences=inputs + states_rnd,  # Random states given as inputs
        outputs_info=[states_init[i] for i in idx_det],
        non_sequences=params,
        n_steps=n_steps,
        strict=True,
        return_list=True,
    )

    for i, j in enumerate(idx_det):
        states[j] = pm.Deterministic(states_name[j], out[i], **kwargs)
    assert all(s is not None for s in states), 'All states shall be initialized'

    return states

With this function, the pendulum model from above can be simplified to:

@is_rnd
def step_v(inputs, states, params):
    x, v = states
    d, k, sigma, dt = params
    v_next = -pm.Normal.dist(mu=d * dt * v * pt.abs(v), sigma=sigma) - k * dt * pt.sin(x) + v
    return v_next

@is_det
def step_x(inputs, states, params):
    x, v = states
    d, k, sigma, dt = params
    x_next = v * dt + x
    return x_next

with pm.Model() as model:
    time = pd.RangeIndex(0, 200)
    dt = time.step
    n_steps = time.size

    model.add_coord('idx_time', time)

    k = pm.Gamma('k', mu=0.02, sigma=0.01)
    d = pm.Gamma('d', mu=0.4, sigma=0.2)
    sigma = pm.Gamma('sigma', mu=0.02, sigma=0.01)
    x0 = pm.Normal('x0', mu=0, sigma=0.5)
    v0 = pm.Normal('v0', mu=0, sigma=0.5)
    eps = 0.01
    
    x, v = StateSpace(
        states_name=['x', 'v'],
        states_init=[x0, v0],
        params=[d, k, sigma, dt],
        states_step=[step_x, step_v],
        n_steps=n_steps,
        dims=('idx_time',),
    )

    x_obs = pm.Normal('x_obs', mu=x, sigma=eps, dims=('idx_time',))

It’s not yet perfect, but it’s serving my needs quite well. :wink:

However, what is still annoying is the need to explicitly mark the random and deterministic states with the decorators @is_rnd and @is_det. To me, this was just an intermediate solution, since I really had no idea by which criteria random and deterministic variables can be reliably distinguished. Does anyone know how to check programmatically whether a PyTensor variable is random or deterministic?


The deterministic expression won’t have any default updates, so you could use that.

Scan's number of steps doesn’t need to be a Python integer; it can be a PyTensor scalar integer (constant or otherwise).

PS: that looks pretty neat

Why do you have a second normal noise around mu besides the one already in the scan?

Setting scan steps dynamically inside a CustomDist (via shape, for example) is a pain point I’ve hit in the past. It’s why we define the sequence length outside of the distribution in the structural forecasting example.

Agreed that this is extremely promising!

My only guess is that the function has to handle the NoneConst shape that PyMC first calls the distribution with to find out the support ndim when initializing a variable (unless ndim_supp / signature is provided in advance).

Yes, this was exactly the problem, since n_steps can’t be None.

from pytensor import scan

import pymc as pm
from pymc.pytensorf import collect_default_updates
from pymc.distributions.shape_utils import rv_size_is_none

def dist(shape):
    if rv_size_is_none(shape):
        n_steps = 0
    else:
        n_steps = shape[-1]
        
    def step(xtm1):
        x = pm.Normal.dist(xtm1)
        return x, collect_default_updates(x)
        
    seq, _ = scan(
        step,
        outputs_info=[pm.math.zeros(())],
        n_steps=n_steps
    )
    return seq
    

n_steps = pm.Poisson.dist(mu=10)
grw = pm.CustomDist.dist(dist=dist, shape=(n_steps,))
print(pm.draw(grw).shape, pm.draw(grw).shape)  # (8,) (12,)

Although arguably n_steps should be a parameter of a time series, and not implied by the shape.

It makes most sense for it to be implied by data (if available). If there’s no observation, it should probably be set by a parameter, sure.

Thanks for your fast replies and the hints! :+1:

Regarding the shape issue, I previously had no problems performing prior sampling, and therefore probably none with pm.draw either (not sure if I explicitly tested it). But as documented above, the samplers (especially numpyro & blackjax) kept running into IndexError: Array slice indices must have static start/stop/step to be used with NumPy indexing syntax. [...]. So @ricardoV94, I think your example above is not fully representative on this point, as it doesn't call a sampler.

The deterministic expression won’t have any default updates so you could use that.

I’m not 100% sure what you mean by “default updates” in this context. Could you maybe provide a short code snippet to explain?


When you call pm.pytensorf.collect_default_updates(outputs=..., inputs=...) on a deterministic expression, it will come back empty.

I don’t quite grasp why you are distinguishing / reordering them, though, so I may be missing the bigger point.