Good afternoon!
I’m writing to see if y’all think that this dynamic occupancy model would be a good fit for `DiscreteMarkovChain` and `marginalize` in the PyMC Extras package. In this model, we detect animals at occupied sites (aka patches), $z=1$, with probability $p$. We assume that animals can’t be detected at unoccupied sites, $z=0$. Between seasons, unoccupied sites can be colonized with probability $\gamma$, and occupied sites can go extinct with probability $\epsilon$. Within a season, each site is surveyed repeatedly and is assumed to be closed to colonization/extinction dynamics. Below is some simulation code and an example of how to fit the model in NumPyro. While the NumPyro version works great, I’m curious whether this would be a good fit for the goodies in pymc-extras. Personally, I find PyMC syntax a little more digestible.
Thanks in advance for your help!
Phil
from jax import random
from numpyro.contrib.control_flow import scan
from numpyro.infer import NUTS, MCMC, Predictive
import arviz as az
import jax.numpy as jnp
import numpy as np
import numpyro
import numpyro.distributions as dist
# ---------------------------------------------------------------------------
# Simulation settings
# ---------------------------------------------------------------------------

# seed shared by the NumPy simulation and the JAX PRNG key
RANDOM_SEED = 1792

# study design: number of sites, seasons, and repeat surveys within a season
SITE_COUNT = 250
SEASON_COUNT = 10
SURVEY_COUNT = 3

# number of between-season intervals (one fewer than the number of seasons)
interval_count = SEASON_COUNT - 1

# data-generating values for the colonization-extinction ("colext") model
PSI_TRUE = 0.6  # occupancy probability in the first season
GAMMA_TRUE = 0.15  # colonization probability
EPSILON_TRUE = 0.3  # extinction probability
P_TRUE = 0.4  # per-survey detection probability at an occupied site
def simulate_data():
    """Generate a synthetic detection history under the colext model.

    Latent occupancy starts from a Bernoulli(PSI_TRUE) draw at every site and
    then evolves season by season: an occupied site stays occupied with
    probability 1 - EPSILON_TRUE, and an empty site becomes occupied with
    probability GAMMA_TRUE. Each survey at an occupied site yields a
    detection with probability P_TRUE; empty sites never yield one.

    Returns
    -------
    numpy.ndarray
        0/1 array of shape (SITE_COUNT, SEASON_COUNT, SURVEY_COUNT).
    """
    generator = np.random.default_rng(RANDOM_SEED)

    # latent occupancy state per site and season, filled in season by season
    occupied = np.zeros((SITE_COUNT, SEASON_COUNT), dtype=int)
    occupied[:, 0] = generator.binomial(n=1, p=PSI_TRUE, size=SITE_COUNT)

    for season in range(1, SEASON_COUNT):
        was_occupied = occupied[:, season - 1]
        # persistence probability where occupied, colonization where empty
        occupancy_prob = np.where(
            was_occupied == 1, 1 - EPSILON_TRUE, GAMMA_TRUE
        )
        occupied[:, season] = generator.binomial(n=1, p=occupancy_prob)

    # detection is only possible where the site is occupied
    detection_prob = occupied * P_TRUE
    history_shape = (SITE_COUNT, SEASON_COUNT, SURVEY_COUNT)
    # trailing axis lets the per-season probability broadcast over surveys
    return generator.binomial(
        n=1, p=detection_prob[:, :, np.newaxis], size=history_shape
    )
def dynamic_occupancy(detection_history):
    """NumPyro model for a dynamic (colonization-extinction) occupancy process.

    The discrete occupancy states are declared with parallel enumeration, so
    they are marginalized rather than sampled.

    Parameters
    ----------
    detection_history
        Array whose shape unpacks as (sites, seasons, surveys). Its entries
        are supplied as observations to the Bernoulli detection likelihoods.
    """
    site_count, season_count, survey_count = detection_history.shape

    # keeps probabilities away from the exact 0/1 boundary for the sampler
    clamp = dist.util.clamp_probs

    # flat Uniform(0, 1) priors on the four scalar probabilities
    psi = numpyro.sample("psi", dist.Uniform(0, 1))  # first-season occupancy
    gamma = numpyro.sample("gamma", dist.Uniform(0, 1))  # colonization
    epsilon = numpyro.sample("epsilon", dist.Uniform(0, 1))  # extinction
    p = numpyro.sample("p", dist.Uniform(0, 1))  # detection given occupancy

    def observe_surveys(site_label, occupancy, observed):
        """Score one season of surveys; call from inside the "sites" plate."""
        # detections can only happen at occupied sites
        detect_prob = occupancy * p
        with numpyro.plate("surveys", survey_count):
            numpyro.sample(
                site_label,
                dist.Bernoulli(clamp(detect_prob)),
                # transpose so that sites fall on the last axis
                obs=observed.T,
            )

    def transition_and_detect(carry, y_t):
        """Advance occupancy by one season and score that season's surveys."""
        # state and counter handed over from the previous scan step
        prev_state, step = carry

        with numpyro.plate("sites", site_count):
            # occupied sites persist with probability 1 - epsilon;
            # empty sites are colonized with probability gamma
            occupancy_prob = (
                prev_state * (1 - epsilon) + (1 - prev_state) * gamma
            )
            z = numpyro.sample(
                "z",
                dist.Bernoulli(clamp(occupancy_prob)),
                infer={"enumerate": "parallel"},  # marginalize the state
            )
            observe_surveys("y", z, y_t)

        # pass on the new state and bumped counter; nothing is accumulated
        return (z, step + 1), None

    # first season: occupancy depends on psi alone
    with numpyro.plate("sites", site_count):
        z0 = numpyro.sample(
            "z0",
            dist.Bernoulli(clamp(psi)),
            infer={"enumerate": "parallel"},
        )
        # likelihood of the first season's surveys only
        observe_surveys("y0", z0, detection_history[:, 0])

    # remaining seasons: put the season axis first so scan steps over it
    later_seasons = jnp.swapaxes(detection_history[:, 1:], 0, 1)
    scan(transition_and_detect, (z0, 0), later_seasons)
# simulate a detection history to fit the model to
detections = simulate_data()
# JAX PRNG key for the sampler; reuses the seed used for the simulation
rng_key = random.PRNGKey(RANDOM_SEED)
# use NUTS as the MCMC kernel for the dynamic occupancy model
nuts_kernel = NUTS(dynamic_occupancy)
# configure the MCMC run: 4 chains, each with 500 warmup and 1000 sampling steps
# NOTE(review): whether the 4 chains run in parallel presumably depends on how
# many devices JAX can see -- confirm for your setup
mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=1000, num_chains=4)
# run the sampler on the simulated detections, then print a summary of the run
mcmc.run(rng_key, detections)
mcmc.print_summary()