Simple Markov chain with Aesara scan

Hello, I’m very new to PyMC, so apologies if this is a silly question. I’m trying to create a simple Markov chain using Aesara’s scan, but I’m getting an error about missing inputs that I don’t really understand. Here is a minimal example:

import aesara
import pymc as pm

k = 10

with pm.Model() as markov_chain:
    
    transition_probs = pm.Uniform('transition_probs', lower=0, upper=1, shape=2)

    initial_state = pm.Bernoulli('initial_state', p=0.5)

    def transition(previous_state):
        p = transition_probs[previous_state]
        return pm.Bernoulli('chain', p=p)

    output, updates = aesara.scan(fn=transition,
                                  outputs_info=dict(initial=initial_state),
                                  n_steps=k)
    
with markov_chain:
    trace = pm.sample_prior_predictive(100)
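
For reference, here is the generative process I’m trying to express, simulated in plain NumPy with no PyMC (the transition probabilities below are made-up values, just for illustration):

```python
import numpy as np

# Made-up values for illustration: probability that the next state is 1,
# given that the previous state was 0 or 1 respectively.
transition_probs = np.array([0.2, 0.8])

rng = np.random.default_rng(42)
k = 10

state = rng.binomial(1, 0.5)  # initial state ~ Bernoulli(0.5)
chain = []
for _ in range(k):
    # the next state is Bernoulli, with p selected by the previous state
    state = rng.binomial(1, transition_probs[state])
    chain.append(state)

print(chain)  # a length-10 sequence of 0s and 1s
```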

Thanks for any suggestions!

You should pass transition_probs as an explicit input to scan, because Aesara scan cannot reliably handle closures over graph tensors:

with pm.Model() as markov_chain:
    ...
    def transition(previous_state, transition_probs):
        p = transition_probs[previous_state]
        return pm.Bernoulli.dist(p=p)

    output, updates = aesara.scan(fn=transition,
                                  outputs_info=[initial_state],
                                  non_sequences=[transition_probs],
                                  n_steps=k)
    markov_chain.register_rv(output, name="mc_chain")

You cannot create named PyMC variables inside the scan step function. You can, however, register the whole scan output as an RV itself, manually:

import numpy as np
import aesara
import pymc as pm

k = 10

with pm.Model() as markov_chain:
    
    transition_probs = pm.Uniform('transition_probs', lower=0, upper=1, shape=2)
    initial_state = pm.Bernoulli('initial_state', p=0.5)

    def transition(previous_state, transition_probs, old_rng):
        p = transition_probs[previous_state]
        next_rng, next_state = pm.Bernoulli.dist(p=p, rng=old_rng).owner.outputs
        return next_state, {old_rng: next_rng}

    rng = aesara.shared(np.random.default_rng())
    output, updates = aesara.scan(fn=transition,
                                  outputs_info=dict(initial=initial_state),
                                  non_sequences=[transition_probs, rng],
                                  n_steps=k)
    assert updates
    markov_chain.register_rv(output, name="p_chain")

with markov_chain:
    trace = pm.sample_prior_predictive(1000, compile_kwargs=dict(updates=updates))

Don’t forget to specify updates and pass them to the sampling function, or the scan won’t be seeded properly across draws.
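
As a rough analogy in plain NumPy (not Aesara; the helper functions below are mine, purely for illustration): the shared RNG is like a Generator object whose state must be carried forward between draws, and that carrying-forward is exactly what the updates dictionary does.

```python
import numpy as np

def draw_without_update(seed=0):
    # State is reset on every call: analogous to compiling the scan
    # without its updates, so every draw starts from the same RNG state.
    rng = np.random.default_rng(seed)
    return rng.integers(0, 2, size=20)

def draw_with_update(rng):
    # State advances in place: analogous to applying the updates.
    return rng.integers(0, 2, size=20)

a = draw_without_update()
b = draw_without_update()
assert (a == b).all()  # identical draws: the state never advanced

rng = np.random.default_rng(0)
c = draw_with_update(rng)
d = draw_with_update(rng)
assert (c != d).any()  # the shared state advanced between draws
```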

Unfortunately RandomVariables are a bit messy with scan :confused:

Haha @junpenglao beat me to it! Don’t forget the updates!

Oh, so Aesara scan can indeed handle closures :sweat_smile: strange, I remember getting an error when I tried that before.

I think the problem appeared later with the gradient of the logp

Haha ok, so it is still better to write all the inputs explicitly (which makes sense)

Thanks so much for your replies, both! On your point about RandomVariables being messy with scan: is there another approach to a Markov chain that you think would be less messy? (I have tried using normal Python loops, but I have read in other posts that this will be much slower.)

Also, when I replace the sampling call with pm.sample() to get the posterior (which I assume would be the same in this case, but not in the more complicated cases I am hoping to build up to), I get the warning UserWarning: Moment not defined for variable and an error too.

The code needed is what is messy/ugly (subjective opinion), but the approach should be perfectly fine.

The moment message is not an error, just a warning. You can set an initval when you call model.register_rv if you want to fix the initial point to a specific value. If you are okay with a draw from the prior, you can set initval="prior" to silence the warning. If your scan variable is observed, you won’t need to; in that case you pass the observed data via the data argument of model.register_rv.

For the error, I need to try it out :slight_smile:

I think the error you see may be coming from the sampler proposing out-of-bounds values for the Bernoulli variable, like p = transition_probs[2]. Again, this is only an issue if you are not conditioning on observed data, but actually sampling the scan sequence.
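
To make that failure mode concrete, here is a plain-NumPy sketch (values made up) of what happens when a proposed state falls outside {0, 1}:

```python
import numpy as np

transition_probs = np.array([0.3, 0.7])  # valid indices are 0 and 1 only

# A generic discrete sampler does not know the state must stay in {0, 1};
# a proposed state of 2 makes the probability lookup fail:
err = None
try:
    p = transition_probs[2]
except IndexError as exc:
    err = exc

print(err)  # e.g. "index 2 is out of bounds for axis 0 with size 2"
```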

If you specify a better sampler manually, it seems to work:

import numpy as np
import aesara
import pymc as pm

k = 10

with pm.Model() as markov_chain:
    
    transition_probs = pm.Uniform('transition_probs', lower=0, upper=1, shape=2)
    initial_state = pm.Bernoulli('initial_state', p=0.5)

    def transition(previous_state, transition_probs, old_rng):
        p = transition_probs[previous_state]
        next_rng, next_state = pm.Bernoulli.dist(p=p, rng=old_rng).owner.outputs
        return next_state, {old_rng: next_rng}

    rng = aesara.shared(np.random.default_rng())
    mc_chain, updates = aesara.scan(fn=transition,
                                    outputs_info=dict(initial=initial_state),
                                    non_sequences=[transition_probs, rng],
                                    n_steps=k)
    assert updates
    markov_chain.register_rv(mc_chain, name="mc_chain", initval="prior")

with markov_chain:
    pm.sample(chains=1, step=pm.BinaryMetropolis([mc_chain]))

Ah that makes sense, thanks!

I’m trying to run the code example from @ricardoV94 above, but I’m getting this error:

---------------------------------------------------------------------------
NotImplementedError                       Traceback (most recent call last)
Cell In[47], line 16
     13     return next_state, {old_rng: next_rng}
     15 rng = aesara.shared(np.random.default_rng())
---> 16 mc_chain, updates = aesara.scan(fn=transition,
     17                               outputs_info=dict(initial = initial_state),
     18                               non_sequences=[transition_probs, rng],
     19                               n_steps=k)
     20 assert updates
     21 markov_chain.register_rv(mc_chain, name="mc_chain", initval="prior")

File ~/opt/anaconda3/envs/pymc/lib/python3.11/site-packages/aesara/scan/basic.py:464, in scan(fn, sequences, outputs_info, non_sequences, n_steps, truncate_gradient, go_backwards, mode, name, profile, allow_gc, strict, return_list)
    462 for elem in wrap_into_list(non_sequences):
    463     if not isinstance(elem, Variable):
--> 464         non_seqs.append(at.as_tensor_variable(elem))
    465     else:
    466         non_seqs.append(elem)

File ~/opt/anaconda3/envs/pymc/lib/python3.11/site-packages/aesara/tensor/__init__.py:49, in as_tensor_variable(x, name, ndim, **kwargs)
     17 def as_tensor_variable(
     18     x: TensorLike, name: Optional[str] = None, ndim: Optional[int] = None, **kwargs
     19 ) -> "TensorVariable":
     20     """Convert `x` into an equivalent `TensorVariable`.
     21 
     22     This function can be used to turn ndarrays, numbers, `ScalarType` instances,
   (...)
     47 
     48     """
---> 49     return _as_tensor_variable(x, name, ndim, **kwargs)

File ~/opt/anaconda3/envs/pymc/lib/python3.11/functools.py:909, in singledispatch.<locals>.wrapper(*args, **kw)
    905 if not args:
    906     raise TypeError(f'{funcname} requires at least '
    907                     '1 positional argument')
--> 909 return dispatch(args[0].__class__)(*args, **kw)

File ~/opt/anaconda3/envs/pymc/lib/python3.11/site-packages/aesara/tensor/__init__.py:56, in _as_tensor_variable(x, name, ndim, **kwargs)
     52 @singledispatch
     53 def _as_tensor_variable(
     54     x: TensorLike, name: Optional[str], ndim: Optional[int], **kwargs
     55 ) -> "TensorVariable":
---> 56     raise NotImplementedError(f"Cannot convert {x!r} to a tensor variable.")

NotImplementedError: Cannot convert transition_probs to a tensor variable.

If you are using PyMC >= 5, you have to replace Aesara with PyTensor (import pytensor and use pytensor.scan; the API is the same).
