This week I returned to this problem to address some limitations of the StateSpace implementation above that I ran into while using it. However, I eventually got stuck on two limitations that are, as far as I understand, currently inherent to PyMC:
1. `CustomDist` only supports `dist` functions that return a single random variable, not a list or tuple of several random variables. To work around this, you currently have to stack all variables together before returning them from `dist`. This is very cumbersome when the variables have different shapes (e.g. time series of shape `(n_steps,)` and `(3,3,n_steps)`), and it may introduce further problems (see 2. below). I assume there is no fundamental reason why `CustomDist` should not support `dist` functions with multiple return values.
2. Several operations seem to be unsupported inside the `dist` function, especially in combination with `scan`. For example, simple operations such as `x = pt.specify_shape(x, (n_steps,))` or `x = pt.stack([x0, x1])` on `scan` outputs caused `RuntimeError: The logprob terms of the following value variables could not be derived: {x}`. This seems to be related to issue #6351. Since the JAX samplers require fixed shapes, this is quite a significant limitation, and I don't see a reason why `pt.specify_shape` should affect the logprob terms.
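To make the stacking workaround from the first limitation concrete, here is a minimal sketch of the bookkeeping it requires. It uses plain NumPy arrays as a stand-in for the symbolic tensors, and the helpers `pack` and `unpack` are hypothetical names chosen for illustration, not part of PyMC or PyTensor.

```python
import numpy as np

n_steps = 5  # hypothetical number of time steps

# Two outputs with different shapes, as in the example above.
x = np.arange(n_steps, dtype=float)                              # shape (n_steps,)
P = np.arange(9 * n_steps, dtype=float).reshape(3, 3, n_steps)   # shape (3, 3, n_steps)


def pack(x, P):
    """Flatten both arrays and join them into one vector (hypothetical helper)."""
    return np.concatenate([x.ravel(), P.ravel()])


def unpack(flat, n_steps):
    """Invert pack(): split the flat vector and restore the original shapes."""
    x = flat[:n_steps]
    P = flat[n_steps:].reshape(3, 3, n_steps)
    return x, P


flat = pack(x, P)                       # single array of length 10 * n_steps
x_back, P_back = unpack(flat, n_steps)  # original shapes recovered
```

Inside a `dist` function the same idea would be written with `pt.concatenate` and `reshape` on symbolic tensors. Whether the logprob can still be derived through those extra operations is exactly the concern raised in the second limitation, so this is only an illustration of the shape handling, not a confirmed working pattern.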
Well, I don't want to call this a feature request, since it's an open-source project, but unless one or both limitations are resolved it will be difficult for me to move on with this problem. I already had a quick look at the `CustomDist` code, but it is too deep in the PyMC internals for me to attempt a fix myself.
P.S.: I also tried an alternative approach, implementing a non-centered time series so that the state-space function remains completely deterministic and the random variables are provided as sequence inputs. In principle this worked, and it would have provided a very nice model interface, but the samplers showed very strong convergence problems. So for my problem I assume that I need to stick to the centered parametrization…
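For readers unfamiliar with the non-centered idea, here is a rough sketch of what "deterministic state-space function with random variables as sequence inputs" means. The AR(1)-style recursion, the function name `state_space`, and all parameter values are assumptions made up for illustration; they are not the model from this thread, and NumPy is used in place of a `scan`.

```python
import numpy as np


def state_space(x0, a, sigma, eps):
    """Deterministic recursion driven by the innovations eps.

    Toy model, assumed for illustration: x[t] = a * x[t-1] + sigma * eps[t].
    """
    x = np.empty(len(eps))
    prev = x0
    for t, e in enumerate(eps):
        prev = a * prev + sigma * e
        x[t] = prev
    return x


rng = np.random.default_rng(0)
eps = rng.standard_normal(10)  # in a model, these would be the sampled variables
x = state_space(x0=1.0, a=0.9, sigma=0.5, eps=eps)
```

All randomness lives in `eps`; the states `x` are a deterministic function of it. In the centered parametrization the states themselves are the random variables, which is why that version has to go through `CustomDist`.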