Tracking public opinion over time with Dirichlet

I am tracking public opinion over time using a Dirichlet, where the public opinion in any month is informed by polling and by opinion in the previous month.

My question is: how do I vectorise this?

model = pm.Model()
with model:
    # temporal model for monthly voting intention
    vi = []
    for i in range(P["n_months"]):
        if not i:
            a = P["start"]
        else:
            a = vi[i - 1]
        vi.append(
            pm.Dirichlet(f"vi_{i:02}", a=a * P["innovation"], shape=(P["n_parties"],))
        )
    vi_stack = pm.Deterministic("vi_stack", pm.math.stack(vi, axis=1))
    print(vi_stack.ndim)

    # likelihood / observation model
    for i in range(P["n_polls"]):
        # register the observed variable; do not overwrite the data in P["y"]
        pm.Dirichlet(
            f"y_{i:03}",
            a=vi_stack[:, P["month"][i]] * P["innovation"],
            observed=P["y"][i],
        )

Full notebook here: https://github.com/bpalmer4/Australian-Federal-Election-2025/blob/main/notebooks/_poll_agg_states_pymc.ipynb

There’s no direct way to vectorize because it’s a recursive model, but you have options.

You could parametrize the simplex differently, for example by using multiple Gaussian random walks and taking the softmax at each time step (you need one fewer random walk than the number of dimensions).

In that case the whole series can be vectorized easily because grw = normal(shape=t).cumsum()
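To make the softmax-of-random-walks idea concrete, here is a minimal NumPy sketch of the arithmetic only (it is not a PyMC model). The sizes, the innovation scale `sigma`, and the choice to pin the last party's logit at zero are all illustrative assumptions; in a PyMC model the normal draws would be random variables rather than samples.

```python
import numpy as np

rng = np.random.default_rng(0)
n_months, n_parties = 24, 4  # illustrative sizes (assumed)
sigma = 0.05                 # assumed innovation scale on the logit scale

# One Gaussian random walk per free dimension: the cumulative sum of normal steps.
steps = rng.normal(0.0, sigma, size=(n_months, n_parties - 1))
grw = steps.cumsum(axis=0)

# Pin the last party's logit at zero (hence one fewer walk than parties),
# then apply a row-wise softmax to land on the simplex each month.
logits = np.concatenate([grw, np.zeros((n_months, 1))], axis=1)
logits = logits - logits.max(axis=1, keepdims=True)  # numerical stability
vi = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)
```

Every row of `vi` is a valid vector of vote shares, and no loop over months is needed.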

Otherwise, if you want to stick with Dirichlets, you can use a Scan. That doesn’t vectorize but keeps the symbolic graph short; see "Time Series Models Derived From a Generative Graph" in the PyMC example gallery.


The alpha parameter of the Dirichlet shouldn’t itself be a simplex, so one only needs a strictly positive random walk. There’s also the AR distribution if you don’t like the random walk assumption.
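One way to read the "strictly positive random walk" suggestion is a random walk on the log scale, exponentiated to give the Dirichlet alpha directly. The NumPy sketch below simulates that under assumed values (the starting level of 50 and the step scale of 0.1 are made up for illustration); it is a forward simulation, not a fitted model.

```python
import numpy as np

rng = np.random.default_rng(1)
n_months, n_parties = 12, 3  # illustrative sizes (assumed)

# Random walk on the log scale; exponentiating keeps every alpha strictly positive.
log_steps = rng.normal(0.0, 0.1, size=(n_months, n_parties))
log_alpha = np.log(50.0) + log_steps.cumsum(axis=0)
alpha = np.exp(log_alpha)

# One simulated poll per month, drawn directly from Dirichlet(alpha_t).
polls = np.stack([rng.dirichlet(a) for a in alpha])

# The vote share implied by alpha_t is its normalised value.
implied_share = alpha / alpha.sum(axis=1, keepdims=True)
```

Here there is no latent simplex variable at all; the shares are recovered by normalising alpha.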

If you have features you want to include in the hidden states (e.g. ARIMAX hidden state) you have to do a scan.

I was suggesting replacing the first Dirichlet (v_t) itself by another prior on the simplex. The alpha parameter need not come from a simplex but it’s a common parametrization I’ve seen: alpha = simplex * concentration.

If I read you correctly, you are suggesting dropping the first Dirichlet/simplex (v_t) altogether and parametrizing the observed Dirichlet alpha directly with a positive process? That would be my first approach as well, but I don’t know if @Bryan_Palmer had a reason to start with that model.


It’s straightforward to compute a single vector of all the a * P["innovation"] values in one go. What I don’t know is whether you could then unfold the Dirichlet into something vectorized in PyMC.

a = [P["start"], *vi[:-1]]  # the prior for month i uses a[i] * P["innovation"]

The alpha = simplex * concentration parameterization has the same number of degrees of freedom (N - 1 from the simplex, 1 from the concentration) and is usually a much better-behaved parameterization of the Dirichlet (and of the Beta, which is just a two-element Dirichlet that only returns one of the elements). There’s a discussion of this in Chapter 5 of Gelman et al.'s Bayesian Data Analysis, around the first hierarchical model introduced there for rat clinical trials, where they reparameterize as a mean and a concentration and discuss priors.
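As a quick illustration of the mean-and-concentration reading, the NumPy sketch below draws from Dirichlet(mean * concentration) for two concentrations. The mean vector and the two concentration values are invented for the example. The sample mean stays at the simplex mean, while the spread shrinks as the concentration grows, following the standard Dirichlet variance mean_i * (1 - mean_i) / (concentration + 1).

```python
import numpy as np

rng = np.random.default_rng(2)
mean = np.array([0.35, 0.33, 0.12, 0.20])  # a point on the simplex (assumed values)

sample_means = {}
sample_sds = {}
for concentration in (10.0, 1000.0):
    alpha = mean * concentration  # alpha = simplex * concentration
    draws = rng.dirichlet(alpha, size=20_000)
    sample_means[concentration] = draws.mean(axis=0)
    sample_sds[concentration] = draws.std(axis=0)

# Analytic standard deviation for comparison with the simulated one.
analytic_sd = {c: np.sqrt(mean * (1.0 - mean) / (c + 1.0)) for c in sample_sds}
```

Separating the two roles this way lets a prior on the mean and a prior on the concentration be chosen independently.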

Thanks for the suggestion to use scan. I have got it half working - the model compiles, as can be seen in this map.

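import numpy as np
import pymc as pm
import pytensor
import pytensor.tensor as pt
from pymc.pytensorf import collect_default_updates

# coords, n_parties, n_months, historic_conc, polling_conc, polled_in_month
# and observed_polls are defined outside this snippet.
with pm.Model(coords=coords) as model: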

    # Initial voting intention (Dirichlet prior)
    vi_init = pm.Dirichlet("vi_init", a=np.ones(n_parties))

    def dirichlet_transition(a_prev, historic_conc):
        """Updates each step of the sequence based on history."""
        a_next = pm.Dirichlet.dist(a=a_prev * historic_conc)
        return a_next, collect_default_updates([a_next])

    def vi_dist(vi_init, historic_conc, _size):
        """Iterate over time to update voting intentions."""
        sequence, _ = pytensor.scan(
            fn=dirichlet_transition,
            outputs_info=[{"initial": vi_init, "taps": [-1]}],
            non_sequences=[historic_conc],
            n_steps=n_months - 1,
            strict=True,
        )
        return sequence

    vi_steps = pm.CustomDist(
        "vi_steps",
        vi_init,
        historic_conc,
        dist=vi_dist,
        dims=("steps", "parties"),
    )

    vi_init_2d = vi_init[None, :]  # type: ignore[unsubscriptable-object]
    vi = pm.Deterministic(
        name="vi",
        var=pt.concatenate([vi_init_2d, vi_steps], axis=0),
        dims=("months", "parties"),
    )

    # Likelihood model: observed polls follow a Dirichlet distribution
    _observed = pm.Dirichlet(
        "observed",
        a=vi[polled_in_month, :] * polling_conc,  # nth_row = x[n, :]
        observed=observed_polls,
        dims=("polls", "parties"),
    )

However, it generates a host of errors at run-time. The code snippet follows.

dirichlet.py (3.8 KB)

The errors appear esoteric (to me at least), and they go on for pages. I am not sure which of the identified issues is actually the problem, although the scan statement seems to be at the heart of things.

Any help would be greatly appreciated.