Yes, you can. The problem is just some internal bookkeeping. The discrete Markov chain distribution is just a recursive function (a `scan`, in PyTensor lingo) over states. The current state indexes the transition probability matrix, and that row parameterizes a Categorical distribution, which draws the next state. Variables passed to `scan` are classified by the role they play at each time step:
- `sequences`: variables that should be iterated over
- `non_sequences`: variables that should be provided as-is at each step
- `outputs_info`: variables that should be recursively fed back into the function
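As a rough illustration of those three roles, here is a plain-Python loop that mimics the order in which `pytensor.scan` hands arguments to the step function: sequence elements first, then the previous output, then the non-sequences. `toy_scan` is a hypothetical stand-in written for this sketch, not a PyTensor function, and it is not how scan is actually implemented.

```python
def toy_scan(step, sequences, outputs_info, non_sequences):
    """Hypothetical plain-Python mimic of pytensor.scan's argument order.

    Handles a single sequence, a single recurrent output and a single
    non-sequence only.
    """
    prev = outputs_info                   # fed back into step at every iteration
    results = []
    for x_t in sequences:                 # iterated over, one element per step
        prev = step(x_t, prev, non_sequences)  # non-sequence passed as-is
        results.append(prev)
    return results


def step(x_t, acc_tm1, scale):
    # Argument order: sequence element, previous output, non-sequence
    return acc_tm1 + scale * x_t


running = toy_scan(step, sequences=[1.0, 2.0, 3.0], outputs_info=0.0, non_sequences=2.0)
print(running)  # [2.0, 6.0, 12.0]
```

In the model below, the stack of transition matrices `Ps` plays the role of the sequence and the previous state `x_tm1` is the recurrent output.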
In `DiscreteMarkovChain`, `P` is treated as a non-sequence. You can write your own `CustomDist` that treats it as a sequence instead, which lets the transition matrix change over time:
```python
import numpy as np
import pymc as pm
import pytensor
import pytensor.tensor as pt
from pymc.pytensorf import collect_default_updates

n_timesteps = 20
coords = {'time': np.arange(n_timesteps),
          'state': [0, 1],
          'next_state': [0, 1]}

with pm.Model(coords=coords) as m:
    def step(P, x_tm1):
        # The previous state selects a row of this step's transition matrix
        x_t = pm.Categorical.dist(p=P[x_tm1])
        # Collect the RNG updates so every scan step draws fresh randomness
        return x_t, collect_default_updates([x_t])

    def markov_chain(x0, Ps, size):
        # CustomDist passes `size` last; the chain length comes from Ps instead
        states, _ = pytensor.scan(step,
                                  outputs_info=[x0],
                                  sequences=[Ps])
        return states

    P = pm.Dirichlet('P', a=[1, 1], dims=['state', 'next_state'])
    Ps = pt.stack([P] * n_timesteps)  # We need as many transition matrices as time steps now

    # Deterministically switch to state 0 at time index 10
    go_to_zero = pt.as_tensor([[1.0, 0.0], [1.0, 0.0]])
    Ps = pm.Deterministic('Ps', Ps[10].set(go_to_zero), dims=['time', 'state', 'next_state'])

    x0 = pm.Bernoulli('x0', p=0.5)
    obs = pm.CustomDist('obs', x0, Ps, dist=markov_chain, dims='time')

    prior = pm.sample_prior_predictive()
```
There’s a lot of boilerplate here, but there are many examples and discussions of both `scan` and `CustomDist` on this Discourse, so I recommend searching for those terms. There is also this tutorial that goes into some detail on these topics.
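If the scan graph is hard to read, a NumPy forward simulation of the same generative process may help. This is a sketch, not what PyMC runs: `simulate_chain` is a hypothetical helper, and the transition matrix below is an arbitrary example value rather than a draw from the Dirichlet prior.

```python
import numpy as np


def simulate_chain(x0, Ps, rng):
    """Hypothetical helper: draw states[t] ~ Categorical(Ps[t][states[t-1]]).

    Mirrors the scan step: the previous state picks a row of the transition
    matrix for the current time step. x0 itself is not included in the output.
    """
    states = np.empty(len(Ps), dtype=int)
    prev = x0
    for t, P_t in enumerate(Ps):
        prev = rng.choice(P_t.shape[1], p=P_t[prev])
        states[t] = prev
    return states


rng = np.random.default_rng(0)
n_timesteps = 20
P = np.array([[0.9, 0.1],
              [0.2, 0.8]])            # arbitrary example matrix
Ps = np.stack([P] * n_timesteps)      # one matrix per time step
Ps[10] = [[1.0, 0.0], [1.0, 0.0]]     # force state 0 at time index 10

states = simulate_chain(x0=1, Ps=Ps, rng=rng)
print(states[10])  # 0, whatever the earlier states were
```

Because every row of `Ps[10]` puts all its mass on state 0, the draw at time index 10 cannot depend on the previous state, which is the behaviour the prior predictive plot should show.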
Anyway, we can see that the prior deterministically goes to state 0 at time 10, as requested:
```python
prior.prior.obs.sel(chain=0).plot.line(x='time', hue='draw', add_legend=False);
```
Use one of the draws from the prior as observed data to do a parameter recovery exercise:
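```python
import arviz as az

draw_to_recover = prior.prior.sel(chain=0, draw=1)
# Flattened true values of x0 and P; a list ref_val must follow the order
# in which plot_posterior draws its panels
param_values = np.concatenate([draw_to_recover[var].values.ravel() for var in ['x0', 'P']]).tolist()
data = draw_to_recover.obs.values

with pm.observe(m, {'obs': data}):
    idata = pm.sample()

az.plot_trace(idata, var_names=['~Ps'])
```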
Sampling goes well:
And we do a reasonable job recovering the transition probabilities. We do poorly on the initial state, but that’s not very surprising to me: initial conditions in time series models tend to be nuisance parameters that are only weakly informed by the data.
```python
az.plot_posterior(idata, var_names=['~Ps'], ref_val=param_values)
```
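A quick way to see why `P` is recoverable while `x0` is not: every observed transition from `obs[t-1]` to `obs[t]` is evidence about one row of `P`, but `x0` only enters the likelihood through the very first draw. The sketch below counts transitions in a state sequence and row-normalizes them, a rough maximum-likelihood estimate of `P` that ignores the prior. `transition_frequencies` is a hypothetical helper and the toy sequence is made up; it skips the transition into time index 10, because `Ps[10]` is fixed by hand and carries no information about `P`.

```python
import numpy as np


def transition_frequencies(states, n_states=2, skip_into=()):
    """Hypothetical helper: row-normalized counts of observed transitions.

    states[t-1] -> states[t] is counted for every t >= 1 except t in skip_into
    (e.g. a step whose transition matrix was set by hand).
    """
    counts = np.zeros((n_states, n_states))
    for t in range(1, len(states)):
        if t in skip_into:
            continue
        counts[states[t - 1], states[t]] += 1
    row_totals = counts.sum(axis=1, keepdims=True)
    # Rows with no observed transitions come out as NaN instead of a guess
    with np.errstate(invalid="ignore"):
        return counts / row_totals, counts


toy_states = np.array([0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 1])  # made-up sequence
freqs, counts = transition_frequencies(toy_states, skip_into={10})
print(freqs)
```

For this toy sequence the estimate is 0.625/0.375 for row 0 and 0.5/0.5 for row 1, built from 10 counted transitions. Calling it on `data` from the recovery exercise with `skip_into={10}` gives a rough empirical counterpart to the posterior for `P`; nothing comparable exists for `x0`, which only shows up in one draw.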


