Hidden Markov Model with aesara

Hello, this is a follow-up question based on a question I asked yesterday here.

I have simulated a hidden Markov model in which an underlying latent variable Z determines the Poisson rate of the observed variable.
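Written out, the generative model implemented by the simulation code is as follows. The notation is introduced here and does not appear in the original code: $p_k$ stands for `probability_of_mode_1_in_next_period[k]` and $\lambda_k$ for `pois_rates[k]`.

$$
\begin{aligned}
Z_1 &= 0, \\
Z_{t+1} \mid Z_t = k &\sim \operatorname{Bernoulli}(p_k), && (p_0, p_1) = (0.1,\ 0.8), \\
y_t \mid Z_t = k &\sim \operatorname{Poisson}(\lambda_k), && (\lambda_0, \lambda_1) = (1,\ 100), \quad t = 1, \dots, n,\ n = 100.
\end{aligned}
$$

In the PyMC model, $\lambda_1$ is parametrised as `lambda_0 + additional_lambda`.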

I have made the rates very different, so it should be very easy to identify which state Z is in. I have been able to code up a model in PyMC that uses a Python loop to create a separate variable for every data point, and that does work. However, it is quite slow, so I have been trying to do the same thing with Aesara, which is much quicker. I must have something wrong in my setup, though: the model doesn't identify the states, and I get warnings about the target acceptance rate.
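To put a number on "very easy to identify": with the rates fixed at their simulated values ($\lambda_0 = 1$, $\lambda_1 = 100$), the log-likelihood ratio of state 1 versus state 0 for a single count $y$ is $y \log 100 - 99$. So any count of 22 or more favours state 1 and any count of 21 or less favours state 0. A Poisson(1) draw essentially never reaches 22 and a Poisson(100) draw essentially never falls to 21. The minimal sketch below assumes the rates are known and ignores the Markov prior on Z; the function names are hypothetical.

```python
import math


def poisson_logpmf(y: int, lam: float) -> float:
    """log P(Y = y) for Y ~ Poisson(lam)."""
    return y * math.log(lam) - lam - math.lgamma(y + 1)


def log_likelihood_ratio(y: int, lam_0: float = 1.0, lam_1: float = 100.0) -> float:
    """log P(y | state 1) - log P(y | state 0); positive values favour state 1.

    Assumes the rates are known (fixed at the simulated values by default)
    and ignores the Markov prior on the hidden state.
    """
    return poisson_logpmf(y, lam_1) - poisson_logpmf(y, lam_0)


# a few typical and borderline counts
for count in [0, 1, 3, 21, 22, 90, 100, 110]:
    print(count, round(log_likelihood_ratio(count), 1))
```

For typical counts the evidence is about 90 nats or more in one direction. Given the rates, the data alone should therefore pin Z down almost exactly.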

```python
import numpy as np
import aesara
import pymc as pm

## Simulate data

n = 100
probability_of_mode_1_in_next_period = [0.1, 0.8]

# start off in mode 0
Z = [0]

# run the chain forward: the next state is 1 with a probability that depends on the current state
for i in range(n - 1):
    Z.append(np.random.binomial(n=1, p=probability_of_mode_1_in_next_period[Z[i]]))

# define the Poisson rates for each mode
pois_rates = [1, 100]

# simulate our observed data based on this
y = np.random.poisson(lam=np.where(Z, pois_rates[1], pois_rates[0]))


## Fit model using PyMC

with pm.Model() as markov_chain:

    # priors on the transition probabilities: P(next state = 1 | current state = 0 or 1)
    transition_probs = pm.Uniform('transition_probs', lower=0, upper=1, shape=2)

    # prior on initial state
    initial_state = pm.Bernoulli('initial_state', p=0.5)

    # priors on the Poisson rates (formulated as a base rate plus an additional rate for mode 1)
    lambda_0 = pm.Uniform('lambda_0', lower=0, upper=2)
    additional_lambda = pm.Uniform('additional_lambda', lower=50, upper=150)

    # run the Markov chain
    def transition(previous_state, transition_probs, old_rng):
        p = transition_probs[previous_state]
        next_rng, next_state = pm.Bernoulli.dist(p=p, rng=old_rng).owner.outputs
        return next_state, {old_rng: next_rng}

    rng = aesara.shared(np.random.default_rng())
    # note: scan's output does not include `initial_state`; output[0] is one step after it
    output, updates = aesara.scan(fn=transition,
                                  outputs_info=dict(initial=initial_state),
                                  non_sequences=[transition_probs, rng],
                                  n_steps=len(y))

    assert updates
    markov_chain.register_rv(output, name="p_chain", initval="prior")

    # now choose lambda based on states
    lam = pm.math.switch(aesara.tensor.eq(output, 1), lambda_0 + additional_lambda, lambda_0)

    # likelihood
    y_model = pm.Poisson("y_model", lam, observed=y)

with markov_chain:
    trace = pm.sample(1000, step=pm.BinaryMetropolis([output]), chains=4)
```
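Some context on the sampling step: as far as I can tell from the PyMC source, `pm.BinaryMetropolis` proposes a move for the whole binary vector at once. It flips each element independently with probability `1 - 0.5**scaling`, which is 0.5 at the default `scaling=1`, and the scaling is adapted during tuning. `pm.BinaryGibbsMetropolis`, by contrast, updates one element at a time. The NumPy sketch below is not PyMC code. It only mimics the two proposal styles (the helper names are hypothetical) to show how many of the 100 states one joint proposal changes.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 100


def joint_flip_proposal(z, scaling=1.0):
    """Flip every element independently with probability 1 - 0.5**scaling.

    Mimics the joint binary proposal described above (assumption, not PyMC code).
    """
    p_jump = 1.0 - 0.5 ** scaling
    flip = rng.random(z.shape) < p_jump
    return np.where(flip, 1 - z, z)


def single_site_proposal(z, i):
    """Flip only element i, as in one Metropolis-within-Gibbs update."""
    z_new = z.copy()
    z_new[i] = 1 - z_new[i]
    return z_new


z = rng.integers(0, 2, size=n)

# how many hidden states each kind of proposal changes
n_changed_joint = [int(np.sum(joint_flip_proposal(z) != z)) for _ in range(1000)]
n_changed_single = int(np.sum(single_site_proposal(z, 0) != z))

print("joint proposal, mean states changed:", np.mean(n_changed_joint))
print("single-site proposal, states changed:", n_changed_single)
```

Under the simulated rates, flipping a correctly labelled state typically costs on the order of 100 nats of log-likelihood. A proposal that changes about 50 states at once is therefore very unlikely to be accepted once the chain is near a good configuration, whereas single-site updates change only one state per proposal.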

Any ideas much appreciated!
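For comparison, a commonly used alternative is to not sample Z at all. This is a different technique from the one in the post, sketched here as an assumption rather than a fix. The forward algorithm sums over all $2^n$ hidden paths in $O(nK^2)$ time ($K = 2$ states), which turns the likelihood into a smooth function of the transition probabilities and rates that NUTS can handle on its own. The NumPy sketch below uses the same parametrisation as the simulation: `p[k]` is the probability of state 1 next period given current state k, and by default the chain starts in state 0. The function names are hypothetical. The sketch checks itself against brute-force enumeration of all paths on a six-step example. In PyMC, one way to use it would be to rewrite the recursion with `aesara.scan` and add the result to the model via `pm.Potential`; that port is not shown.

```python
import itertools
import math

import numpy as np


def poisson_logpmf(y, lam):
    """Elementwise log Poisson pmf for an array of integer counts."""
    y = np.asarray(y)
    lgam = np.array([math.lgamma(v + 1) for v in y.ravel()]).reshape(y.shape)
    return y * np.log(lam) - lam - lgam


def logsumexp(a, axis):
    """Numerically stable log(sum(exp(a))) along one axis."""
    m = np.max(a, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.sum(np.exp(a - m), axis=axis))


def hmm_log_likelihood(y, p, lam, init=(1.0, 0.0)):
    """log p(y) for a 2-state Poisson HMM, marginalising the hidden path.

    p[k]   : P(Z_{t+1} = 1 | Z_t = k)
    lam[k] : Poisson rate in state k
    init   : distribution of the first state (default: start in state 0)
    """
    y = np.asarray(y)
    p = np.asarray(p, dtype=float)
    # log_A[j, k] = log P(Z_{t+1} = k | Z_t = j)
    log_A = np.log(np.column_stack([1.0 - p, p]))
    # emission log-probabilities, shape (n, 2)
    log_B = np.column_stack([poisson_logpmf(y, lam[0]), poisson_logpmf(y, lam[1])])
    with np.errstate(divide="ignore"):  # log(0) = -inf for an impossible start state
        log_alpha = np.log(np.asarray(init, dtype=float)) + log_B[0]
    for t in range(1, len(y)):
        # alpha_t(k) = B_t(k) * sum_j alpha_{t-1}(j) * A[j, k], in log space
        log_alpha = logsumexp(log_alpha[:, None] + log_A, axis=0) + log_B[t]
    return logsumexp(log_alpha, axis=0)


def brute_force_log_likelihood(y, p, lam, init=(1.0, 0.0)):
    """Same quantity by enumerating all 2**n hidden paths (only for tiny n)."""
    total = 0.0
    for path in itertools.product([0, 1], repeat=len(y)):
        prob = init[path[0]]
        for prev, nxt in zip(path[:-1], path[1:]):
            prob *= p[prev] if nxt == 1 else 1.0 - p[prev]
        for v, k in zip(y, path):
            prob *= lam[k] ** v * math.exp(-lam[k]) / math.factorial(int(v))
        total += prob
    return math.log(total)


y_demo = [0, 2, 9, 11, 1, 0]
print(hmm_log_likelihood(y_demo, p=(0.1, 0.8), lam=(1.0, 10.0)))
print(brute_force_log_likelihood(y_demo, p=(0.1, 0.8), lam=(1.0, 10.0)))
```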