Hello!
I am trying to estimate the parameters of a hidden semi-Markov model (HSMM) using PyMC. When I started with a hidden Markov model (HMM), I followed the guide available at this link: How to wrap a JAX function for use in PyMC — PyMC example gallery
Now I want to use the same methodology to estimate the HSMM's parameters: the duration distribution, the emission distribution, and the transition matrix. This is the log-likelihood of the model:
import jax
import jax.numpy as jnp
import jax.scipy as jsp
from jax.scipy.special import logsumexp


def hmm_logp(
    emission_observed,
    mu,
    sigma,
    logp_initial_state,
    logp_transition,
    lamda,
):
    """Compute the marginal log-likelihood of a single HSMM process."""
    T = emission_observed.shape[0]
    N = logp_initial_state.shape[0]
    D_max = 80
    log_alpha = jnp.full((T, N), -jnp.inf)

    # Log-likelihood of the observed emissions for each (time step x hidden state)
    logp_emission = jsp.stats.multivariate_normal.logpdf(
        emission_observed[:, None],
        mu,
        sigma,
    )
    # Log-probability of each duration d = 0..D_max-1, per state (Poisson durations)
    logp_duration = jsp.stats.poisson.logpmf(jnp.arange(D_max), lamda[:, None])
    # logp_duration = jnp.log(DiscreteWeibull(jnp.arange(1, D_max), lamda[:, None], beta[:, None]))

    log_alpha = log_alpha.at[0].set(logp_initial_state + logp_emission[0])

    def u_t(i, d, t, B):
        """Compute u_t(i, d) as a sum of log emission probabilities over the last d+1 steps."""

        def scan_fn(carry, offset):
            sum_log_prob, valid_offset = carry
            # Only add to the sum if the offset is within range
            valid_offset = jnp.where(offset <= d, True, False)
            sum_log_prob = jnp.where(valid_offset, sum_log_prob + B[t - offset, i], sum_log_prob)
            return (sum_log_prob, valid_offset), None

        # Initialize the carry (running sum of log-probabilities) to zero
        (sum_log_prob, _), _ = jax.lax.scan(scan_fn, (0.0, True), jnp.arange(D_max))
        return sum_log_prob

    @jax.jit
    def compute_u_t(i, d, t, B):
        return u_t(i, d, t, B)

    def forward_step_init(carry, inputs):
        # Forward recursion for t = 1..D_max-1, where an initial segment can still cover the whole prefix
        log_alpha_prev, t = carry
        logp_emission_t, = inputs
        log_alpha = log_alpha_prev
        for i in range(N):
            init_term = logp_initial_state[i] + logp_duration[i, t] + compute_u_t(i, t, t, logp_emission)
            recursive_term = logsumexp(jnp.array(
                [log_alpha_prev[t - 1 - d, j] + logp_transition[j, i] + logp_duration[i, d]
                 + compute_u_t(i, d, t, logp_emission)
                 for d in range(D_max) for j in range(N) if j != i]))
            log_alpha = log_alpha.at[t, i].set(logsumexp(jnp.array([init_term, recursive_term])))
        return (log_alpha, t + 1), None

    inputs = (logp_emission[1:D_max],)
    (carry_final, _), log_alpha_scan = jax.lax.scan(forward_step_init, (log_alpha, 1), inputs)

    def forward_step(carry, inputs):
        # Forward recursion for t >= D_max, where only the transition term contributes
        log_alpha_prev, t = carry
        logp_emission_t, = inputs
        log_alpha = log_alpha_prev
        for i in range(N):
            log_alpha = log_alpha.at[t, i].set(logsumexp(jnp.array(
                [log_alpha_prev[t - 1 - d, j] + logp_transition[j, i] + logp_duration[i, d]
                 + compute_u_t(i, d, t, logp_emission)
                 for d in range(D_max) for j in range(N) if j != i])))
        return (log_alpha, t + 1), None

    inputs = (logp_emission[D_max:],)
    # Continue from the log_alpha filled in by the first scan, not from the empty array
    (carry_final, _), log_alpha_scan = jax.lax.scan(forward_step, (carry_final, D_max), inputs)

    last_alpha = carry_final[-1, :]
    log_likelihood = jsp.special.logsumexp(last_alpha)
    return log_likelihood
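For reference, this is a small, self-contained shape check with made-up dimensions (the arrays and values are arbitrary and only show how I call the function). Even at this size the tracing and compilation already take noticeably long, which is part of the problem described below:

import jax
import jax.numpy as jnp

# Made-up dimensions: T observations of a K-dimensional Gaussian emission, N hidden states
T, K, N = 200, 2, 3
key = jax.random.PRNGKey(0)
emissions = jax.random.normal(key, (T, K))

mu = jnp.zeros((N, K))                                # per-state emission means
sigma = jnp.broadcast_to(jnp.eye(K), (N, K, K))       # per-state emission covariances
logp_initial_state = jnp.log(jnp.full(N, 1.0 / N))    # uniform initial distribution (log scale)
logp_transition = jnp.log(jnp.full((N, N), 1.0 / N))  # uniform transition matrix (log scale)
lamda = jnp.array([5.0, 10.0, 20.0])                  # Poisson duration rates per state

log_like = jax.jit(hmm_logp)(emissions, mu, sigma, logp_initial_state, logp_transition, lamda)
print(log_like)  # a single scalar log-likelihood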
I use this log-likelihood as the input to the Bayesian inference, wrapping it roughly as sketched below. The problem is that this computation is quite heavy, and the learning process takes a very long time when the number of observations and the maximum duration are large.
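Concretely, I wrap the jitted log-likelihood in a PyTensor Op, roughly following the linked guide, and feed it to the model through a pm.Potential. The sketch below is simplified: the Op name and the priors are just placeholders, it reuses the toy arrays from the shape check above (converted to NumPy), and it keeps sigma, the initial distribution and the transition matrix fixed. No grad method is defined, which is why I fall back to Metropolis instead of NUTS:

import numpy as np
import pymc as pm
import pytensor.tensor as pt
from pytensor.graph import Apply, Op

jitted_hsmm_logp = jax.jit(hmm_logp)

class HSMMLogpOp(Op):
    """Black-box Op that evaluates the JAX log-likelihood (no gradient implemented)."""

    def make_node(self, emission_observed, mu, sigma, logp_initial_state, logp_transition, lamda):
        inputs = [pt.as_tensor_variable(x) for x in
                  (emission_observed, mu, sigma, logp_initial_state, logp_transition, lamda)]
        outputs = [pt.dscalar()]
        return Apply(self, inputs, outputs)

    def perform(self, node, inputs, outputs):
        result = jitted_hsmm_logp(*inputs)
        outputs[0][0] = np.asarray(result, dtype=node.outputs[0].dtype)

hsmm_logp_op = HSMMLogpOp()

with pm.Model() as model:
    # Placeholder priors, only to illustrate how the Op enters the model
    mu_rv = pm.Normal("mu", 0.0, 1.0, shape=(N, K))
    lamda_rv = pm.Gamma("lamda", 2.0, 0.1, shape=N)
    pm.Potential(
        "hsmm_loglike",
        hsmm_logp_op(
            np.asarray(emissions), mu_rv, np.asarray(sigma),
            np.asarray(logp_initial_state), np.asarray(logp_transition), lamda_rv,
        ),
    )
    idata = pm.sample(step=pm.Metropolis())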
I turn to anyone who has studied this type of model: do you know of any way around the runtime problem? With Metropolis the algorithm takes about 8 hours, and I chose it precisely to avoid the gradient computation that HMC requires. Or am I doing something wrong?
I am grateful for any help.