State Space Model with Random & Deterministic Dynamics

This week, I worked on this problem again to address some limitations of the StateSpace implementation above that I ran into during usage. However, I eventually got stuck because of two limitations that are, to my understanding, currently inherent to PyMC:

  1. CustomDist only supports dist functions that return a single random variable, not a list/tuple of multiple random variables. To circumvent this limitation, you currently have to stack all variables together before returning them from dist. This is very cumbersome if you are dealing with variables of different shapes (e.g. time series of shape (n_steps,) and (3,3,n_steps)), and it may also introduce further problems (please see 2. below). I assume there is no strict reason why CustomDist should not support dist functions with multiple return values.
  2. Several operations do not seem to be supported inside the dist function, especially in combination with scan. For example, simple operations like x = pt.specify_shape(x, (n_steps,)) or x = pt.stack([x0, x1]) on scan outputs caused RuntimeError: The logprob terms of the following value variables could not be derived: {x}. This seems to be related to issue #6351. Since the JAX samplers require fixed shapes, this is quite a significant limitation, and I don’t see a reason why pt.specify_shape should affect the logprob terms.
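
To illustrate the bookkeeping that the stacking workaround in 1. requires, here is a minimal sketch in plain NumPy. The helper names `pack` and `unpack` and the value `n_steps = 5` are made up for illustration. Inside an actual dist function the same idea would presumably be written with PyTensor operations such as `pt.concatenate` and `reshape`, and this sketch says nothing about whether PyMC can then derive the logprob (see 2.).

```python
import numpy as np

n_steps = 5  # hypothetical series length, for illustration only

# Two quantities with different shapes, as in the example in point 1.
x = np.arange(n_steps, dtype=float)                              # shape (n_steps,)
P = np.arange(9 * n_steps, dtype=float).reshape(3, 3, n_steps)   # shape (3, 3, n_steps)

shapes = [x.shape, P.shape]


def pack(arrays):
    """Flatten each array and join everything into one 1-D vector."""
    return np.concatenate([a.ravel() for a in arrays])


def unpack(flat, shapes):
    """Split the 1-D vector back into arrays with the given shapes."""
    out, start = [], 0
    for shape in shapes:
        size = int(np.prod(shape))
        out.append(flat[start:start + size].reshape(shape))
        start += size
    return out


packed = pack([x, P])                    # single array returned in place of a tuple
x_back, P_back = unpack(packed, shapes)  # recovered downstream of the distribution
```

The shapes have to be tracked by hand on both sides of the distribution, which is the cumbersome part; returning a tuple from dist would make this unnecessary.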

Well, I don’t want to call this a feature request since it’s an open source project :wink: , but without resolving one or both limitations it would be difficult for me to move on with this problem. I already had a quick look at the CustomDist code, but this is too deep in the internals of PyMC for me to give it a try myself.

P.S.: I also tried an alternative approach, implementing a non-centered time series so that the state-space function remains completely deterministic and the random variables are provided as sequence inputs. In theory this worked, and it would have provided a very nice model interface, but the samplers showed very strong convergence problems. So, for my problem I assume that I need to stick to the centered parametrization…
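
For readers unfamiliar with the non-centered idea, here is a rough sketch in plain Python. It assumes a simple AR(1)-type transition, which is a placeholder and not my actual model; the function name `simulate_states` and the parameters `phi` and `sigma` are likewise hypothetical. In a PyMC model the innovations would be a standard normal random variable and the loop would be a scan.

```python
import random


def simulate_states(x0, phi, sigma, innovations):
    """Deterministic state recursion: all randomness enters via `innovations`.

    Assumed (hypothetical) transition: x_t = phi * x_{t-1} + sigma * eps_t
    """
    states = []
    x = x0
    for eps in innovations:
        x = phi * x + sigma * eps
        states.append(x)
    return states


random.seed(0)
# Non-centered: sample standard-normal innovations first ...
innovations = [random.gauss(0.0, 1.0) for _ in range(10)]
# ... then push them through the purely deterministic state function.
states = simulate_states(x0=0.0, phi=0.9, sigma=0.5, innovations=innovations)
```

In the centered version, each state would instead be drawn directly as a normal variable around `phi * x_{t-1}` with scale `sigma`, so the sampler works on the states themselves rather than on the innovations.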