You cannot create PyMC variables inside the scan step function. You can, however, manually register the whole scan sequence as an RV itself:
import numpy as np
import aesara
import pymc as pm
# Number of Markov-chain transitions generated by the scan below (n_steps).
k = 10
with pm.Model() as markov_chain:
    # transition_probs[s] is used (in `transition`) as the Bernoulli
    # probability of the next state given the current state s, so shape=2
    # holds one probability per possible state (0 or 1).
    transition_probs = pm.Uniform('transition_probs', lower=0, upper=1, shape = 2)
    # Starting state of the chain: a Bernoulli(0.5) draw.
    initial_state = pm.Bernoulli('initial_state', p = 0.5)

    def transition(previous_state, transition_probs, old_rng):
        """Scan step: draw the next chain state from the previous one.

        Parameters follow the scan call below: the recurrent state
        (`previous_state`, seeded by `outputs_info`), then the
        `non_sequences` (`transition_probs`, `old_rng`).

        Returns the new state together with an updates dict that maps the
        incoming RNG to its advanced state, so scan can carry the RNG
        forward between steps.
        """
        p = transition_probs[previous_state]
        # A registered PyMC variable cannot be created here, so build an
        # unregistered `.dist` and unpack its graph node's outputs as
        # (updated rng, drawn value).
        next_rng, next_state = pm.Bernoulli.dist(p = p, rng=old_rng).owner.outputs
        return next_state, {old_rng: next_rng}

    # Shared RNG threaded through the scan as a non-sequence; the update
    # returned by `transition` is collected in `updates`.
    rng = aesara.shared(np.random.default_rng())
    output, updates = aesara.scan(fn=transition,
                                  outputs_info=dict(initial = initial_state),
                                  non_sequences=[transition_probs, rng],
                                  n_steps=k)
    # Sanity check: scan must have picked up the RNG update from the step;
    # these updates must later be passed to the sampling function.
    assert updates
# Register the whole scanned sequence as a model variable named "p_chain".
markov_chain.register_rv(output, name="p_chain")
# Draw 1000 prior-predictive samples. Scan's RNG updates are handed to the
# compiled sampling function so the random state advances across draws.
with markov_chain:
    trace = pm.sample_prior_predictive(
        1000,
        compile_kwargs={"updates": updates},
    )
Don’t forget to collect the `updates` returned by `scan` and pass them to the sampling function. Otherwise, the scan's random number generator won’t be updated properly across draws.
Unfortunately, RandomVariables are a bit messy to use with scan.