Not saying what you’re doing is wrong, just being pedantic that you don’t need Python integers for the Scan n_steps.
No problem. However, I had already tried freeze_dims_and_data, and it fixed the issue only for nutpie, not for numpyro and blackjax.
Thanks for your explanation of collect_default_updates - I understand and will give it a try!
I don’t quite grasp why you are distinguishing / reordering them, though, so I may be missing the bigger point.
The rationale is that dist_ss_rnd, being called by CustomDist, must return only the random states, so I need to identify them and aggregate them into out[0]. The random states are later passed as sequences to fn_det to calculate the deterministic states, and scan always passes sequence inputs to fn before outputs_info, so I needed to reorder the arguments to fn_det(inputs, ..., states_rnd, ..., states_det, ..., params, ...).
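To make the ordering constraint concrete, here is a pure-Python stand-in for scan's calling convention. It is not PyTensor, and scan_like, make_fn_det and step are hypothetical names used only for illustration. At each step fn receives the current slice of every sequence, then the previous value of every recurrent output, then the non_sequences. The wrapper takes a step function written as step(inputs, states, params), with random and deterministic states mixed in one list, and exposes it in the order fn_det(inputs..., states_rnd..., states_det..., params...).

```python
def scan_like(fn, sequences, outputs_info, non_sequences, n_steps):
    """Hypothetical pure-Python mimic of pytensor.scan's argument order.

    At each step, fn is called with: the current slice of every sequence,
    then the previous value of every recurrent output, then the non_sequences.
    """
    states = list(outputs_info)
    trace = []
    for t in range(n_steps):
        out = fn(*[seq[t] for seq in sequences], *states, *non_sequences)
        states = list(out) if isinstance(out, (list, tuple)) else [out]
        trace.append(states[0] if len(states) == 1 else tuple(states))
    return trace


def make_fn_det(step, n_inputs, rnd_idx, n_states):
    """Adapt step(inputs, states, params) to scan's positional order.

    rnd_idx lists the positions (in step's state list) of the random states;
    those are fed as sequences, the remaining deterministic states as outputs_info.
    """
    det_idx = [i for i in range(n_states) if i not in rnd_idx]
    n_rnd, n_det = len(rnd_idx), len(det_idx)

    def fn_det(*args):
        inputs = args[:n_inputs]
        rnd = args[n_inputs:n_inputs + n_rnd]
        det = args[n_inputs + n_rnd:n_inputs + n_rnd + n_det]
        params = args[n_inputs + n_rnd + n_det:]
        # Put every state back at its original position for the user's step function
        states = [None] * n_states
        for i, value in zip(rnd_idx, rnd):
            states[i] = value
        for i, value in zip(det_idx, det):
            states[i] = value
        return step(inputs, states, params)  # returns the next deterministic state(s)

    return fn_det


def step(inputs, states, params):
    # Hypothetical model; original state order: [y (deterministic), x (random)]
    (u,) = inputs
    y_prev, x = states
    (a,) = params
    return a * y_prev + x + u


fn_det = make_fn_det(step, n_inputs=1, rnd_idx=[1], n_states=2)
u = [1.0, 0.0, 0.0]        # exogenous input sequence
x_rnd = [0.5, -0.5, 0.25]  # random states, already drawn, passed as a sequence
ys = scan_like(fn_det, sequences=[u, x_rnd], outputs_info=[0.0],
               non_sequences=[0.5], n_steps=3)
```

Here ys evaluates to [1.5, 0.25, 0.375]. In the actual model the same ordering applies, with symbolic tensors passed to pytensor.scan instead of Python lists.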
(I’m not saying it’s impossible to simplify these reordering steps further…)
This week I worked on this problem again, to address some limitations of the StateSpace implementation above that I ran into during use. However, I eventually got stuck because of some limitations that are, as far as I understand, currently inherent to PyMC:
1. CustomDist only supports dist functions that return a single random variable, not a list/tuple of several random variables. To work around this, you currently have to stack all variables together before returning them from dist. This is very cumbersome when dealing with variables of different shapes (e.g. time series of shape (n_steps,) and (3,3,n_steps)), and it may introduce further problems (see 2. below). I assume there is no strict reason why CustomDist should not support dist functions with multiple return values.
2. Several operations seem not to be supported inside the dist function, especially in combination with scan. For example, simple operations like x = pt.specify_shape(x, (n_steps,)) or x = pt.stack([x0, x1]) on scan outputs caused RuntimeError: The logprob terms of the following value variables could not be derived: {x}. This seems to be related to issue #6351. Since the JAX samplers require fixed shapes, this is quite a significant limitation, and I don’t see why pt.specify_shape should affect the logprob terms.
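A minimal NumPy sketch of the stacking workaround for CustomDist described above, assuming every series carries time on its last axis: each series is flattened to rows, the rows are concatenated, and the result is split again afterwards. pack_time_series and unpack_time_series are hypothetical helpers. In the model the same bookkeeping would have to be done with PyTensor ops such as pt.concatenate and reshape, and post-processing scan outputs like this is exactly where the logprob derivation was reported to fail.

```python
import numpy as np


def pack_time_series(*arrays):
    """Stack series of different shapes (time on the last axis) into one 2-D array."""
    n_steps = arrays[0].shape[-1]
    if any(arr.shape[-1] != n_steps for arr in arrays):
        raise ValueError("all series must share the same time-axis length")
    lead_shapes = [arr.shape[:-1] for arr in arrays]
    rows = [arr.reshape(-1, n_steps) for arr in arrays]  # () -> 1 row, (3, 3) -> 9 rows
    return np.concatenate(rows, axis=0), lead_shapes


def unpack_time_series(packed, lead_shapes):
    """Invert pack_time_series: split the rows and restore the original shapes."""
    n_steps = packed.shape[-1]
    sizes = [int(np.prod(shape)) for shape in lead_shapes]  # np.prod(()) == 1
    parts = np.split(packed, np.cumsum(sizes)[:-1], axis=0)
    return [part.reshape(*shape, n_steps) for part, shape in zip(parts, lead_shapes)]


n_steps = 5
a = np.arange(n_steps, dtype=float)                             # shape (n_steps,)
b = np.arange(9 * n_steps, dtype=float).reshape(3, 3, n_steps)  # shape (3, 3, n_steps)

packed, lead_shapes = pack_time_series(a, b)  # packed has shape (10, 5)
a_back, b_back = unpack_time_series(packed, lead_shapes)
```

The round trip works, but every consumer of the packed array has to know the row layout, which is the bookkeeping that makes this workaround cumbersome.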
Well, I don’t want to call this a feature request, since it’s an open-source project, but unless one or both of these limitations are resolved it will be difficult for me to make progress on this problem. I had a quick look at the CustomDist code, but it goes too deep into PyMC’s internals for me to attempt it myself.
P.S.: I also tried an alternative approach with a non-centered time series, so that the state-space function remains completely deterministic and the random variables are provided as sequence inputs. In principle this worked, and it would have given a very nice model interface, but the samplers showed severe convergence problems. So for my problem I assume I need to stick to the centered parametrization…
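A minimal NumPy sketch of the two parametrizations, using a Gaussian random walk as a stand-in for the actual state-space model (which isn't shown in the thread); centered_walk and noncentered_walk are hypothetical names. In the centered form each state is drawn from its conditional distribution given the previous state, so the states themselves are the sampled quantities. In the non-centered form the sampled quantities are standard-normal innovations eps, and the states are a deterministic function of them, which is what passing the random variables as sequence inputs to a deterministic state-space function amounts to.

```python
import numpy as np


def centered_walk(rng, n_steps, sigma, x0):
    """Centered: draw each state directly from x_t ~ Normal(x_{t-1}, sigma)."""
    x = np.empty(n_steps)
    prev = x0
    for t in range(n_steps):
        prev = rng.normal(prev, sigma)
        x[t] = prev
    return x


def noncentered_walk(eps, sigma, x0):
    """Non-centered: states are a deterministic function of eps_t ~ Normal(0, 1)."""
    return x0 + np.cumsum(sigma * eps)


rng = np.random.default_rng(42)
n_steps, sigma, x0 = 20_000, 0.3, 1.0
x_c = centered_walk(rng, n_steps, sigma, x0)
eps = rng.standard_normal(n_steps)  # the free variables in the non-centered form
x_nc = noncentered_walk(eps, sigma, x0)
```

Both constructions imply the same prior over the states (increments with standard deviation sigma); what changes is the geometry of the space the sampler explores, which is why the two can behave quite differently during inference.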
It has more to do with model restrictions: there’s no way to specify distinct dims/observed/transforms/names for the distinct RVs at once when you call pm.Foo(...) inside a model.
You can call model.register_rv with distinct RVs that come from the same node (sidestepping CustomDist altogether), but then you miss out on all the niceties of distributions, namely automatic resizing based on dims/observed.
I doubt the centered approach would sample any better, so you may need some modelling work regardless of the current limitations you’re noting.