Adding Conditional Transition Probabilities to a Hidden Markov Model

I’m trying to use a hidden Markov model for participant responses to a behavioural psychology task. I have a basic model working for the task (code below), but am struggling to figure out how to integrate one more important feature. In my task, for some changes in the hidden state being inferred here, the participant is given a message telling them the state has changed, while for others they are left to infer the change from the stimuli. The key parameter I want to estimate for my participants from this model is the transition probability that would produce the responses the participants made, so it’s important that I can treat the trials with the explicit signal differently. Currently, my transition probability matrix has the form [[1-p, p], [p, 1-p]], but for the steps that are signalled by a message, I would like it to be [[0, 1], [1, 0]]. Is there a way of having different probability matrices for different steps like this using pmx.DiscreteMarkovChain()?

import numpy as np
import pymc as pm
import pymc_experimental as pmx
import pytensor.tensor as pt

# df, observed_data, known_states_array and T come from the task data
with pm.Model() as model:
    # Define single transition probability (simplified model)
    p_switch = pm.Beta("p_switch", alpha=1.4, beta=4)

    # Define prior for emission probabilities
    p_cue_as_expected_for_condition = pm.Beta(
        "p_cue_as_expected_for_condition", alpha=7, beta=2
    )

    # Build emission probability matrix
    emission_probs = pm.math.stack(
        [
            [p_cue_as_expected_for_condition, 1 - p_cue_as_expected_for_condition],
            [1 - p_cue_as_expected_for_condition, p_cue_as_expected_for_condition],
        ]
    )

    # Create a single 2x2 transition matrix
    transition_probs = pm.math.stack(
        [[1 - p_switch, p_switch], [p_switch, 1 - p_switch]]
    )

    # Initial state probabilities using Categorical distribution
    start_state = 1 if df["cue_valid"].iloc[0] else 0
    init_probs = np.zeros(2)
    init_probs[start_state] = 0.98
    init_probs[1 - start_state] = 0.02
    init_dist = pm.Categorical.dist(p=init_probs)

    # Now use DiscreteMarkovChain with static transition probabilities
    hidden_states = pmx.DiscreteMarkovChain(
        "hidden_states", P=transition_probs, init_dist=init_dist, shape=T
    )

    # Convert known_states_array to a NumPy array if it's not already
    known_states_array = np.array(known_states_array)

    # Find indices where hidden states are known
    known_indices = np.where(~np.isnan(known_states_array))[0]

    # Get the known states at those indices
    known_states_values = known_states_array[known_indices].astype(int)

    # Enforce known hidden states using pm.Potential
    if len(known_indices) > 0:
        # Create a boolean array where conditions are met
        is_correct_state = pt.eq(hidden_states[known_indices], known_states_values)
        # Convert the boolean array to log-probabilities: log(0.98) where the
        # sampled state matches the known state, log(0.02) where it does not
        logp = pt.sum(pt.switch(is_correct_state, np.log(0.98), np.log(0.02)))
        pm.Potential("hidden_states_observed", logp)

    # Define observations
    emissions = pm.Categorical(
        "emissions", p=emission_probs[hidden_states], observed=observed_data
    )

We don’t currently support time-varying transition probabilities in DiscreteMarkovChain, but there’s no reason why we couldn’t. Could you open an issue about it on the pymc-experimental repo?

Thanks for the reply. I’ve just raised an issue on the GitHub repo. For the time being, is there likely to be any workaround to make the model work? I tried a few things, but it seems beyond my level of proficiency as a PyMC user. I’m happy to carry on if it seems doable, but it would also be helpful to be told to give up if it just won’t work.

Yes, you can. The problem is just some internal book-keeping. The discrete Markov chain distribution is just a recursive function (scan, in PyTensor lingo) over states. The current state indexes into the transition probability matrix, and that row is used to parameterize a Categorical distribution which draws the next state. Variables passed to scan are classified by the role they play at each time step (a toy example follows the list):

  • sequences: variables that are iterated over, one slice per time step
  • non_sequences: variables that are provided as-is at every step
  • outputs_info: variables whose outputs are recursively fed back into the function
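As a toy illustration of these three roles, here is a self-contained scan (a made-up running example, not part of the model below) that computes out_t = out_{t-1} * a_t + b:

import numpy as np
import pytensor
import pytensor.tensor as pt

a = pt.vector("a")        # sequences: one element is consumed per step
b = pt.scalar("b")        # non_sequences: passed in unchanged at every step
out0 = pt.scalar("out0")  # outputs_info: initial value of the recursion

def toy_step(a_t, out_tm1, b):
    # Argument order is sequences first, then outputs_info, then non_sequences
    return out_tm1 * a_t + b

outs, _ = pytensor.scan(toy_step, sequences=[a], outputs_info=[out0], non_sequences=[b])
f = pytensor.function([a, b, out0], outs)
print(f(np.array([1.0, 2.0, 3.0]), 1.0, 0.0))  # [ 1.  3. 10.]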

In DiscreteMarkovChain, P is treated like a non_sequence. You can write your own CustomDist that treats it like a sequence instead:

import arviz as az
import numpy as np
import pymc as pm
import pytensor
import pytensor.tensor as pt
from pymc.pytensorf import collect_default_updates

n_timesteps = 20
coords={'time':np.arange(n_timesteps), 
        'state':[0, 1], 
        'next_state':[0, 1]}

with pm.Model(coords=coords) as m:    
    def step(P, x_tm1):
        x_t = pm.Categorical.dist(p=P[x_tm1])
        return x_t, collect_default_updates(x_t)
    
    def markov_chain(x0, Ps, shape=None):
        states, _ = pytensor.scan(step,
                                  outputs_info=[x0],
                                  sequences=[Ps])
        
        return states
    
    
    P = pm.Dirichlet('P', a=[1, 1], dims=['state', 'next_state'])
    Ps = pt.stack([P] * n_timesteps) # We need as many transition matrices as time steps now
     
    # Deterministically switch to state 0 at the 10th step
    go_to_zero = pt.as_tensor([[1.0, 0.0], [1.0, 0.0]])
    Ps = pm.Deterministic('Ps', Ps[10].set(go_to_zero), dims=['time', 'state', 'next_state'])
    
    x0 = pm.Bernoulli('x0', p=0.5)
    obs = pm.CustomDist('obs', x0, Ps, dist=markov_chain, dims='time')
    prior = pm.sample_prior_predictive()
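To map this back onto the original question, here is a minimal sketch, assuming a hypothetical signalled_idx array holding the indices of the message trials: build Ps from the single learned matrix (with the p_switch prior from the first post), then overwrite the signalled steps with the deterministic [[0, 1], [1, 0]] matrix. This would replace the P and Ps lines inside the model block above:

# Hypothetical: indices of the trials where the state change is announced
signalled_idx = np.array([4, 10, 15])

# Single learned matrix [[1-p, p], [p, 1-p]], as in the original model
p_switch = pm.Beta("p_switch", alpha=1.4, beta=4)
P = pt.stack([[1 - p_switch, p_switch], [p_switch, 1 - p_switch]])

# One copy per time step, then force a deterministic switch on message trials
Ps = pt.stack([P] * n_timesteps)
always_switch = pt.as_tensor([[0.0, 1.0], [1.0, 0.0]])
Ps = Ps[signalled_idx].set(always_switch)

From there, obs = pm.CustomDist('obs', x0, Ps, dist=markov_chain, dims='time') works exactly as before.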

There’s a lot of boilerplate here, but there are plenty of examples and discussions of both scan and CustomDist on this Discourse; I recommend searching for those terms. There is also this tutorial that goes into some detail on these topics.

Anyway, we can see that the prior deterministically goes to state 0 at time 10, as requested:

prior.prior.obs.sel(chain=0).plot.line(x='time', hue='draw', add_legend=False);

Use one of the draws from the prior as observed data to do a parameter recovery exercise:

draw_to_recover = prior.prior.sel(chain=0, draw=1)
param_values = np.concatenate([draw_to_recover[var].values.ravel() for var in ['x0', 'P']]).tolist()
data = draw_to_recover.obs.values

with pm.observe(m, {'obs':data}):
    idata = pm.sample()

az.plot_trace(idata, var_names=['~Ps'])

Sampling goes well:

And we do a reasonable job recovering the transition probabilities. We beef it on the initial state, but that’s not super surprising to me – initial conditions in time series models tend to be nuisance parameters that are only very weakly informed by the data.

az.plot_posterior(idata, var_names=['~Ps'], ref_val=param_values)


Thank you so much! This looks like exactly what I need.


Instead of repeating the probabilities over time, you could also keep the two possible matrices and index into them based on a carried integer that you toggle at each step with even_step = 1 - even_step. A sketch of this follows below.
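A minimal sketch of that idea, assuming the same CustomDist scaffolding, P, go_to_zero, and n_timesteps as above (the names P_pair and even_step are illustrative): the two candidate matrices are stacked and passed as a non_sequence, while an integer flag is carried through the scan via outputs_info and toggled at each step.

def step(x_tm1, even_step, P_pair):
    # Select this step's transition matrix, then draw the next state from its row
    x_t = pm.Categorical.dist(p=P_pair[even_step][x_tm1])
    # Toggle the flag so the other matrix is used on the next step
    return (x_t, 1 - even_step), collect_default_updates(x_t)

def markov_chain(x0, P_pair, shape=None):
    (states, _), _ = pytensor.scan(
        step,
        outputs_info=[x0, pt.constant(0, dtype="int64")],
        non_sequences=[P_pair],
        n_steps=n_timesteps,
    )
    return states

P_pair = pt.stack([P, go_to_zero])  # the two candidate transition matrices
obs = pm.CustomDist("obs", x0, P_pair, dist=markov_chain, dims="time")

For the original signalled-trial setup, you could instead pass a 0/1 indicator array as a sequences argument and use it to index P_pair, so the choice of matrix comes from the data rather than from toggling.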