Tracking public opinion over time with Dirichlet

Thanks for the suggestion to use scan. I have got it half working - the model compiles, as can be seen in this map.

# Open the model context; the random variables created in the indented body
# are built inside this context.
# NOTE(review): `coords` and `n_parties` are not defined in this snippet —
# presumably set up earlier in the file; confirm.
with pm.Model(coords=coords) as model:

    # Initial voting intention (Dirichlet prior)
    # Concentration is np.ones(n_parties), i.e. the same value (1) for every
    # entry — assumes n_parties is the integer number of parties.
    vi_init = pm.Dirichlet("vi_init", a=np.ones(n_parties))

    def dirichlet_transition(a_prev, historic_conc):
        """Produce the next state of the sequence from the previous one.

        Used as the step function (``fn``) of the ``pytensor.scan`` call in
        ``vi_dist``. The next state is a Dirichlet variable whose
        concentration is ``a_prev`` multiplied by ``historic_conc``.

        Returns a 2-tuple: the new Dirichlet variable, and whatever
        ``collect_default_updates`` yields for that variable.
        """
        concentration = a_prev * historic_conc
        step_draw = pm.Dirichlet.dist(a=concentration)
        step_updates = collect_default_updates([step_draw])
        return step_draw, step_updates

    def vi_dist(vi_init, historic_conc, _size):
        """Build the sequence of voting intentions that follows ``vi_init``.

        ``dirichlet_transition`` is applied repeatedly via ``pytensor.scan``
        for ``n_months - 1`` steps. Each step receives the output of the step
        before it (tap ``-1``), starting from ``vi_init``, and
        ``historic_conc`` is handed unchanged to every step as a
        non-sequence. ``_size`` is accepted but not used in the body.

        Returns the outputs of the scan; the updates that scan returns
        alongside them are discarded.
        """
        # Single recurrent output: seeded with vi_init, looking one step back.
        recurrence = {"initial": vi_init, "taps": [-1]}
        steps, _updates = pytensor.scan(
            fn=dirichlet_transition,
            n_steps=n_months - 1,
            outputs_info=[recurrence],
            non_sequences=[historic_conc],
            strict=True,
        )
        return steps

    # States after the initial one, defined through vi_dist with vi_init and
    # historic_conc passed as its distribution parameters.
    # NOTE(review): historic_conc is not defined anywhere in this snippet —
    # presumably it is created earlier in the file; confirm.
    vi_steps = pm.CustomDist(
        "vi_steps",
        vi_init,
        historic_conc,
        dist=vi_dist,
        dims=("steps", "parties"),
    )

    # Give vi_init a leading axis so it can be stacked on top of vi_steps
    # along axis 0 below.
    vi_init_2d = vi_init[None, :]  # type: ignore[unsubscriptable-object]
    # Full trajectory: the initial state followed by the scanned steps,
    # recorded under the name "vi".
    vi = pm.Deterministic(
        name="vi",
        var=pt.concatenate([vi_init_2d, vi_steps], axis=0),
        dims=("months", "parties"),
    )

    # Likelihood model: observed polls follow a Dirichlet distribution
    # The concentration for each poll is the row of vi selected by
    # polled_in_month, multiplied by polling_conc.
    # NOTE(review): polled_in_month, polling_conc and observed_polls are not
    # defined in this snippet — presumably polled_in_month holds one month
    # index per poll; confirm.
    _observed = pm.Dirichlet(
        "observed",
        a=vi[polled_in_month, :] * polling_conc,  # nth_row = x[n, :]
        observed=observed_polls,
        dims=("polls", "parties"),
    )

However, it generates a host of errors at run-time. The code snippet follows.

dirichlet.py (3.8 KB)

The errors appear esoteric (to me at least), and they go on for pages. I am not sure which of the identified issues is actually the problem (although the scan statement seems to be at the heart of things).

Any help would be greatly appreciated.