Thanks for the suggestion to use scan. I have got it half working: the model compiles, as this model map shows.
with pm.Model(coords=coords) as model:
    # Initial voting intention (Dirichlet prior): one simplex over the parties.
    vi_init = pm.Dirichlet("vi_init", a=np.ones(n_parties))

    def dirichlet_transition(a_prev, historic_conc):
        """Draw the next month's voting intention from the previous month's.

        The next simplex is Dirichlet-distributed with concentration
        ``a_prev * historic_conc``; because ``a_prev`` is itself a simplex,
        the draw has mean ``a_prev`` and ``historic_conc`` sets how tightly
        it stays around that mean.

        Returns the new draw together with the RNG updates that scan needs
        so that every step produces a fresh random draw.
        """
        a_next = pm.Dirichlet.dist(a=a_prev * historic_conc)
        return a_next, collect_default_updates([a_next])

    def vi_dist(vi_init, historic_conc, _size):
        """Build the generative graph for the months after the first.

        Iterates ``dirichlet_transition`` with scan for ``n_months - 1`` steps,
        starting from ``vi_init``. ``_size`` is required by the CustomDist
        ``dist`` signature but is not used here.
        """
        sequence, _ = pytensor.scan(
            fn=dirichlet_transition,
            outputs_info=[{"initial": vi_init, "taps": [-1]}],
            non_sequences=[historic_conc],
            n_steps=n_months - 1,
            strict=True,
        )
        return sequence

    vi_steps = pm.CustomDist(
        "vi_steps",
        vi_init,
        historic_conc,
        dist=vi_dist,
        dims=("steps", "parties"),
        # PyMC does not infer a transform for a CustomDist built from a
        # `dist` function, so without this the sampler works on the raw
        # (unconstrained) values and proposes rows that are not simplexes,
        # where the Dirichlet logp is -inf. The simplex transform acts on
        # the last axis (parties), i.e. row by row.
        transform=pm.distributions.transforms.simplex,
        # Start every month at the uniform simplex; the default initial
        # point of a symbolic CustomDist may not lie on the simplex.
        initval=np.full((n_months - 1, n_parties), 1.0 / n_parties),
    )

    # Prepend the first month so `vi` covers all months: shape (months, parties).
    vi_init_2d = vi_init[None, :]  # type: ignore[unsubscriptable-object]
    vi = pm.Deterministic(
        name="vi",
        var=pt.concatenate([vi_init_2d, vi_steps], axis=0),
        dims=("months", "parties"),
    )

    # Likelihood model: each observed poll follows a Dirichlet distribution
    # centred on the voting intention of the month it was taken in.
    _observed = pm.Dirichlet(
        "observed",
        a=vi[polled_in_month, :] * polling_conc,  # nth_row = x[n, :]
        observed=observed_polls,
        dims=("polls", "parties"),
    )
However, it generates a host of errors at run time. The full code is attached below.
dirichlet.py (3.8 KB)
The errors appear esoteric (to me, at least), and they go on for pages. I am not sure which of the reported issues is actually the problem, although the scan statement seems to be at the heart of things.
Any help would be greatly appreciated.
