Adding Conditional Transition Probabilities to a Hidden Markov Model

Yes, you can. The problem is just some internal bookkeeping. The discrete Markov chain distribution is just a recursive function (a scan, in PyTensor lingo) over states. The current state indexes the transition probability matrix, and that row is used to parameterize a Categorical distribution, which draws the next state. Arguments to scan are classified by the role they play at each time step:

  • sequences: variables that should be iterated over, one slice per step
  • non_sequences: variables that should be provided as-is at each step
  • outputs_info: variables that should be recursively fed back into the function
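
If the three roles feel abstract, here is a rough plain-Python analogue of a scan with a single recursive output. This is only a sketch to show how the arguments are routed; `simple_scan` and `decayed_sum` are made-up names, not part of PyTensor, and the real `pytensor.scan` builds a symbolic graph rather than running a Python loop.

```python
def simple_scan(fn, sequences, outputs_info, non_sequences=()):
    """Plain-Python analogue of a scan with one recursive output (illustrative only)."""
    carry = outputs_info
    outputs = []
    for items in zip(*sequences):
        # Argument order mirrors pytensor.scan: the current slice of each
        # sequence, then the previous output, then the non-sequences.
        carry = fn(*items, carry, *non_sequences)
        outputs.append(carry)
    return outputs


def decayed_sum(x_t, total_tm1, decay):
    # x_t comes from a sequence, total_tm1 is fed back, decay is fixed.
    return decay * total_tm1 + x_t


running = simple_scan(
    decayed_sum,
    sequences=[[1.0, 2.0, 3.0]],
    outputs_info=0.0,
    non_sequences=(0.5,),
)
print(running)  # [1.0, 2.5, 4.25]
```

Here `decay` is handed over unchanged at every step, which is the role P plays in a non_sequence; moving it into `sequences` would give the function a different value at each step.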

In DiscreteMarkovChain, P is treated like a non_sequence. You can write your own CustomDist that treats it like a sequence instead:

import arviz as az
import numpy as np
import pymc as pm
import pytensor
import pytensor.tensor as pt
from pymc.pytensorf import collect_default_updates

n_timesteps = 20
coords={'time':np.arange(n_timesteps), 
        'state':[0, 1], 
        'next_state':[0, 1]}

with pm.Model(coords=coords) as m:    
    def step(P, x_tm1):
        x_t = pm.Categorical.dist(p=P[x_tm1])
        return x_t, collect_default_updates(x_t)
    
    # CustomDist expects the signature dist(*dist_params, size)
    def markov_chain(x0, Ps, size=None):
        states, _ = pytensor.scan(step,
                                  outputs_info=[x0],
                                  sequences=[Ps])
        
        return states
    
    
    P = pm.Dirichlet('P', a=[1, 1], dims=['state', 'next_state'])
    Ps = pt.stack([P] * n_timesteps) # We need as many transition matrices as time steps now
     
    # Deterministically switch to state 0 at time index 10 (every row puts all mass on state 0)
    go_to_zero = pt.as_tensor([[1.0, 0.0], [1.0, 0.0]])
    Ps = pm.Deterministic('Ps', Ps[10].set(go_to_zero), dims=['time', 'state', 'next_state'])
    
    x0 = pm.Bernoulli('x0', p=0.5)
    obs = pm.CustomDist('obs', x0, Ps, dist=markov_chain, dims='time')
    prior = pm.sample_prior_predictive()

There’s a lot of boilerplate here, but there are many examples and discussions of both scan and CustomDist on this discourse; I recommend searching for those terms. There is also this tutorial that goes into some detail on these topics.
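
To see what the scan is doing without any of the PyMC machinery, here is a minimal plain-Python sketch of the same generative process. It is not the model above: the baseline matrix `P` is a made-up value rather than a Dirichlet draw, and `draw_categorical` and `simulate_chain` are hypothetical helper names used only for illustration.

```python
import random

N_TIMESTEPS = 20
FORCED_STEP = 10  # same index that is overridden in Ps in the model


def draw_categorical(probs, rng):
    """Draw an index with the given probabilities."""
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]


def simulate_chain(x0, Ps, rng):
    """Walk the chain, using a different transition matrix at every step."""
    states = []
    x_tm1 = x0
    for P_t in Ps:  # Ps plays the role of the scan `sequences`
        # The previous state picks the row of this step's matrix.
        x_t = draw_categorical(P_t[x_tm1], rng)
        states.append(x_t)
        x_tm1 = x_t  # fed back into the next step, like `outputs_info`
    return states


P = [[0.7, 0.3], [0.4, 0.6]]           # made-up baseline transition matrix
go_to_zero = [[1.0, 0.0], [1.0, 0.0]]  # every row sends the chain to state 0
Ps = [go_to_zero if t == FORCED_STEP else P for t in range(N_TIMESTEPS)]

rng = random.Random(0)
chains = [simulate_chain(rng.randint(0, 1), Ps, rng) for _ in range(200)]
forced_ok = all(chain[FORCED_STEP] == 0 for chain in chains)
print(forced_ok)
```

Every simulated chain lands in state 0 at index 10 regardless of where it was before, which is the behaviour to look for in the prior draws.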

Anyway, we can see that the prior deterministically goes to state 0 at time 10, as requested:

prior.prior.obs.sel(chain=0).plot.line(x='time', hue='draw', add_legend=False);

Use one of the draws from the prior as observed data to do a parameter recovery exercise:

draw_to_recover = prior.prior.sel(chain=0, draw=1)
param_values = np.concatenate([draw_to_recover[var].values.ravel() for var in ['x0', 'P']]).tolist()
data = draw_to_recover.obs.values

with pm.observe(m, {'obs':data}):
    idata = pm.sample()

az.plot_trace(idata, var_names=['~Ps'])

Sampling goes well.

And we do a reasonable job recovering the transition probabilities. We miss on the initial state, but that’s not super surprising to me: initial conditions in time series models tend to be nuisance parameters that are only very weakly informed by the data.

az.plot_posterior(idata, var_names=['~Ps'], ref_val=param_values)
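
As a rough cross-check on the recovered transition probabilities, you can count the observed transitions directly. This is a sketch under a couple of assumptions: the sequence below is made up (it is not the draw used above), the helper names are hypothetical, and it ignores the first transition out of the unobserved x0, so it will only approximately match the sampler. The transition into the forced step is skipped because it tells you nothing about P. With a Dirichlet(1, 1) prior on each row, the conjugate posterior mean of a row is (counts + 1) / (row total + 2).

```python
def transition_counts(obs, n_states=2, skip_steps=()):
    """Count transitions obs[t-1] -> obs[t], skipping the listed time indices."""
    counts = [[0] * n_states for _ in range(n_states)]
    for t in range(1, len(obs)):
        if t in skip_steps:
            continue  # a forced transition carries no information about P
        counts[obs[t - 1]][obs[t]] += 1
    return counts


def dirichlet_posterior_mean(counts, prior_a=1.0):
    """Row-wise posterior mean under a symmetric Dirichlet(prior_a) prior."""
    means = []
    for row in counts:
        total = sum(row) + prior_a * len(row)
        means.append([(c + prior_a) / total for c in row])
    return means


# Made-up 20-step sequence, for illustration only
obs = [0, 0, 1, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0]
counts = transition_counts(obs, skip_steps={10})
means = dirichlet_posterior_mean(counts)
print(counts)  # [[5, 5], [4, 4]]
print(means)   # [[0.5, 0.5], [0.5, 0.5]]
```

With only 18 informative transitions the posterior on P is going to be fairly wide, so "reasonable" recovery is about the best you can expect from a single 20-step chain.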