I also tried doing it with a Potential, calculating everything in PyTensor. It gives roughly the same results (see the code below). I suppose it is roughly similar to SMC, though, at least in terms of its cost function. And because pm.Multinomial.dist is involved, I can't use the NUTS sampler (it complains that the gradient for that operator is not defined).
I think what I am trying to do is unreasonable because I am generating a single trajectory. Whereas the transfer-matrix approach in DiscreteMarkovChain generates the whole probability distribution at each step, I am basically just generating a single sample from those distributions at each step and comparing it to my observed data. Perhaps I should generate at least several trajectories, but if the state space is large, I wonder whether trying to mimic the effect of the transfer matrix with just a few trajectories is reasonable.
import pymc as pm
import numpy as np
import pytensor as pt
import arviz as az
def _calculate_P_pyt(state, f, m, nvariants):
    """Return the per-variant sampling probabilities for the next step.

    Builds an ``nvariants x nvariants`` matrix with ``1 - m`` on the diagonal
    and ``m / (nvariants - 1)`` everywhere else, applies it to the
    element-wise product ``state * f`` and normalises by the sum of that
    product.

    Parameters
    ----------
    state : current per-variant values (assumed to be counts -- confirm
        against the caller).
    f : per-variant multipliers applied to ``state``.
    m : scalar placed (as ``1 - m``) on the diagonal and spread evenly over
        the off-diagonal entries.
    nvariants : Python int, the length of ``state``.

    Returns
    -------
    A symbolic PyTensor vector of length ``nvariants``.
    """
    weighted = state * f

    # Start with every entry equal to the off-diagonal rate m / (nvariants - 1).
    row = pt.tensor.stack([m] * nvariants)
    rates = pt.tensor.tile(row, (nvariants, 1)) / (nvariants - 1)

    # Overwrite the diagonal with 1 - m.
    diagonal = np.eye(nvariants, dtype=bool)
    rates = pt.tensor.set_subtensor(rates[diagonal], 1 - m)

    # `weighted` broadcasts along rows, so row i sums rates[i, j] * weighted[j].
    numerator = pt.tensor.sum(rates * weighted, axis=1)
    return numerator / pt.tensor.sum(weighted)
def _WFM_sim_pyt(rng, f, m, initial_state, nsims, size=None):
    """Symbolically simulate ``nsims`` multinomial resampling steps.

    At every step the sampling probabilities are computed from the previous
    state by ``_calculate_P_pyt`` and the next state is a single multinomial
    draw whose total equals ``initial_state.sum()``.
    (Presumably "WFM" stands for a Wright-Fisher-type model -- confirm.)

    Parameters
    ----------
    rng : random generator; it is wrapped in a PyTensor shared variable.
    f, m : passed unchanged to ``_calculate_P_pyt`` at every step.
    initial_state : array providing ``.size`` and ``.sum()``; the state
        the first step starts from.
    nsims : number of scan steps.
    size : accepted but not used.

    Returns
    -------
    The symbolic scan output collecting the state drawn at each step.
    """
    n_variants = initial_state.size
    total = initial_state.sum()
    shared_rng = pt.shared(rng)

    def transition(prev_state, fitness, mutation):
        probs = _calculate_P_pyt(prev_state, fitness, mutation, n_variants)
        draw = pm.Multinomial.dist(total, probs, rng=shared_rng)
        # The random op's outputs are (updated generator, drawn value).
        updated_rng, new_state = draw.owner.outputs
        # Tell scan how the shared generator advances after this step.
        return new_state, {shared_rng: updated_rng}

    # The updates returned alongside the trajectory are not used here.
    trajectory, _ = pt.scan(
        transition,
        outputs_info=initial_state,
        non_sequences=[f, m],
        n_steps=nsims,
    )
    return trajectory
def WFM_ABC(initial_state, obs, seed=None):
    """Sample ``f`` and the mutation rate by matching simulated to observed proportions.

    A single trajectory is simulated with ``_WFM_sim_pyt`` inside a PyMC
    model, and a ``Potential`` penalises the squared difference between the
    simulated and observed per-step proportions.

    Parameters
    ----------
    initial_state : 1-D array of starting values; its ``size`` sets the
        length of ``f``.
    obs : 2-D array indexed ``[step, variant]``; each row is normalised by
        its own sum before comparison.
    seed : optional seed. It seeds the generator driving the simulated
        trajectory and is forwarded to ``pm.sample`` as ``random_seed``.
        ``None`` (the default) leaves both unseeded.

    Returns
    -------
    The object returned by ``pm.sample``.
    """
    # Build the generator from `seed` here. Previously the body read a
    # module-level `rng`, which left `seed` unused and raised NameError
    # whenever that global was absent.
    rng = np.random.default_rng(seed)

    obs_props = obs / obs.sum(axis=1)[:, None]
    nsims = obs.shape[0]

    with pm.Model():
        # log10(m) = -2 - HalfNormal(2), so m never exceeds 1e-2.
        _log10_m = pm.HalfNormal("_log10_m", 2)
        log10_m = -2 - _log10_m
        f = pm.LogNormal("f", 0, 1, size=initial_state.size)

        counts = _WFM_sim_pyt(rng, f, 10**log10_m, initial_state, nsims)
        fit_props = counts / counts.sum(axis=1)[:, None]

        # Gaussian-like penalty on the proportion mismatch.
        pm.Potential("pot", -0.5 * pm.math.sum((fit_props - obs_props) ** 2))
        idata = pm.sample(random_seed=seed)
    return idata
# Demo: simulate one trajectory from known parameters, then fit the model to it.
seed = 0
f = np.array([0.67786154, 1.89172801, 0.84634297, 0.63536372, 1.19762663])
m = 1e-4
initial_state = np.array([100, 100, 10000, 100, 100])
nobs = 50

# Seed the generator once from `seed`. The earlier version created a seeded
# generator and then immediately replaced it with an unseeded one, so the
# synthetic data changed on every run.
rng = np.random.default_rng(seed)
obs_data = _WFM_sim_pyt(rng, f, m, initial_state, nobs).eval()

# Fit with a seed distinct from the one used to generate the data.
idata = WFM_ABC(initial_state, obs_data, seed=seed + 1)
print(az.summary(idata))