Using pytensor random number generator with scan for Simulator class

I also tried doing it with a Potential, calculating everything in pytensor. It gives roughly the same results (see code below). I suppose it is roughly similar to SMC, at least in terms of its cost function. And because of the Multinomial.dist involvement, I can't use the NUTS sampler (it complains that the gradient for that operator is not defined).
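
In case it is useful to anyone hitting the same error: since the only failure is that NUTS needs gradients, one workaround (a minimal sketch, untested against this exact model) is to assign a gradient-free step method explicitly, or to lean into the SMC similarity directly:

with pm.Model():
  ...  # same model as in WFM_ABC below
  # pick a gradient-free sampler instead of letting NUTS be auto-assigned
  idata = pm.sample(step=pm.Metropolis())
  # or, since the cost function already resembles SMC:
  # idata = pm.sample_smc()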

I think what I am trying to do is unreasonable because I am generating a single trajectory. Whereas the transition-matrix approach in DiscreteMarkovChain propagates the whole probability distribution at each step, I am basically generating a single sample from those distributions at each step and comparing it to my observations. Perhaps I should generate at least a few trajectories, but if the state space is large I wonder whether trying to mimic the effect of the transition matrix with just a handful of trajectories is reasonable.
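
To make that contrast explicit (my own notation): with a transition matrix $P$, the state distribution is propagated exactly, $\pi_{t+1} = \pi_t P$, at every step, whereas my simulator draws a single realisation $x_{t+1} \sim \mathrm{Multinomial}(N, p(x_t))$, with $p(x_t)$ computed by _calculate_P_pyt below, so the Potential only ever sees one noisy sample path instead of the full distribution.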

import pymc as pm
import numpy as np
import pytensor as pt
import arviz as az


def _calculate_P_pyt(state, f, m, nvariants):

  # selection weight of each variant
  denom = state*f

  # mutation matrix: off-diagonal entries m/(nvariants-1), diagonal 1-m
  M = pt.tensor.stack([m for _ in range(nvariants)])
  M = pt.tensor.tile(M, (nvariants, 1))/(nvariants-1)
  mask = np.eye(nvariants, dtype=bool)
  M = M[mask].set(1-m)

  # transition probabilities for the next generation's multinomial draw
  return pt.tensor.sum(M*denom, axis=1)/pt.tensor.sum(denom)
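
# Sanity-check sketch (my addition, not part of the model): a plain NumPy
# version of _calculate_P_pyt that the pytensor graph should reproduce for
# concrete inputs.
def _calculate_P_np(state, f, m, nvariants):
  denom = state*f                                       # selection weights
  M = np.full((nvariants, nvariants), m/(nvariants-1))  # mutation to any other variant
  np.fill_diagonal(M, 1-m)                              # probability of no mutation
  return (M*denom).sum(axis=1)/denom.sum()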

def _WFM_sim_pyt(rng, f, m, initial_state, nsims, size=None):

  nvariants = initial_state.size
  N = initial_state.sum()
  # wrap the numpy Generator in a shared variable so scan can update its state
  rng = pt.shared(rng)

  def transition(*args):

    state, f, m = args

    p = _calculate_P_pyt(state, f, m, nvariants)
    # a RandomVariable node's outputs are (next_rng_state, draw); splitting
    # them out lets scan thread the RNG state through the loop
    next_rng, next_state = pm.Multinomial.dist(N, p, rng=rng).owner.outputs

    return next_state, {rng: next_rng}

  result, updates = pt.scan(transition,
                            outputs_info=initial_state,
                            non_sequences=[f, m],
                            n_steps=nsims)
  return result


def WFM_ABC(initial_state, obs, seed=None):

  rng = np.random.default_rng(seed)

  # compare variant proportions rather than raw counts
  obs_props = obs/obs.sum(axis=1)[:, None]

  nsims = obs.shape[0]

  with pm.Model():

    # prior that keeps the log10 mutation rate below -2
    _log10_m = pm.HalfNormal("_log10_m", 2)
    log10_m = -2 - _log10_m

    f = pm.LogNormal("f", 0, 1, size=initial_state.size)

    # simulate one trajectory under the current parameter draws
    counts = _WFM_sim_pyt(rng, f, 10**log10_m, initial_state, nsims)

    fit_props = counts/counts.sum(axis=1)[:, None]

    # squared-error distance between simulated and observed proportions,
    # acting as a soft ABC-style pseudo-likelihood
    pm.Potential("pot", -0.5*pm.math.sum((fit_props - obs_props)**2))

    idata = pm.sample()

  return idata

seed = 0
f = np.array([0.67786154, 1.89172801, 0.84634297, 0.63536372, 1.19762663])
m = 1e-4
initial_state = np.array([100, 100, 10000, 100, 100])
nobs = 50
rng = np.random.default_rng(seed)

obs_data = _WFM_sim_pyt(rng, f, m, initial_state, nobs).eval()

idata = WFM_ABC(initial_state, obs_data, seed=seed)

print(az.summary(idata))
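
If it helps with checking the fit, the posterior can also be compared visually against the ground-truth f that generated obs_data (standard ArviZ calls; the plotting choices are mine):

az.plot_trace(idata, var_names=["f"])
az.plot_posterior(idata, var_names=["f"], ref_val=list(f))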