Hi all,
I am currently using PyMC to fit an explicit-duration hidden Markov model (a hidden semi-Markov model). This example helped me a lot when I started: How to wrap a JAX function for use in PyMC — PyMC example gallery
I used it as the base code for my semi-Markov model. However, the new likelihood has some specific difficulties: dynamic slices, several nested “for” loops, and conditionals. These make the run time very long when I implement the likelihood as a JAX function or try to vectorize it. Here is the code that computes the likelihood with the forward algorithm:
import numpy as np
import jax.numpy as jnp
import jax.scipy as jsp
from jax.scipy.special import logsumexp


def hmm_logp(
    emission_observed,
    mu,
    sigma,
    logp_initial_state,
    logp_transition,
    lamda,
):
    """Compute the marginal log-likelihood of a single explicit-duration HMM process."""
    T = emission_observed.shape[0]
    N = logp_initial_state.shape[0]
    D_max = 2  # maximum segment duration
    # log_alpha[t, i]: log-probability that a segment in state i ends at step t
    log_alpha = jnp.full((T, N), -np.inf)
    # Log-likelihood of observed emissions for each (step x possible hidden state)
    logp_emission = jsp.stats.multivariate_normal.logpdf(
        emission_observed[:, None],
        mu,
        sigma,
    )
    # logp_duration[i, d] = log P(duration = d + 1 | state i), durations 1..D_max
    logp_duration = jsp.stats.poisson.logpmf(jnp.arange(1, D_max + 1), lamda[:, None])
    # Initial phase: the segment ending at t may be the first one, started at step 0
    for t in range(min(D_max, T)):
        for i in range(N):
            # First segment covers steps 0..t (duration t + 1)
            init_term = (logp_initial_state[i] + logp_duration[i, t]
                         + jnp.sum(logp_emission[: t + 1, i]))
            if t == 0:
                log_alpha = log_alpha.at[t, i].set(init_term)
            else:
                # A segment in state j != i ended at t-1-d; the new one covers t-d..t
                recursive_term = logsumexp(jnp.array([
                    log_alpha[t - 1 - d, j] + logp_transition[j, i] + logp_duration[i, d]
                    + jnp.sum(logp_emission[t - d : t + 1, i])
                    for d in range(t) for j in range(N) if j != i
                ]))
                log_alpha = log_alpha.at[t, i].set(
                    logsumexp(jnp.array([init_term, recursive_term]))
                )
    # Forward pass: every segment now follows an earlier one (duration <= D_max)
    for t in range(D_max, T):
        for i in range(N):
            # Self-transitions (j == i) are not excluded here, unlike the initial phase
            log_alpha = log_alpha.at[t, i].set(logsumexp(jnp.array([
                log_alpha[t - 1 - d, j] + logp_transition[j, i] + logp_duration[i, d]
                + jnp.sum(logp_emission[t - d : t + 1, i])
                for d in range(D_max) for j in range(N)
            ])))
    # Log-likelihood: logsumexp over states of the final alpha values
    log_likelihood = logsumexp(log_alpha[-1, :])
    return log_likelihood
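For reference, the recursion these loops are meant to compute can be written as follows (in probability space; the code works in log space with logsumexp). This assumes $\alpha_t(i)$ is the probability that a segment in state $i$ ends at step $t$, $\pi_i$ is the initial-state probability, $A_{ji}$ is the probability of switching from $j$ to $i$ with self-transitions excluded, $b_i(y_t)$ is the emission density, and $p_i(d)$ is the Poisson pmf with rate $\lambda_i$ at duration $d = 1, \dots, D_{\max}$ (so `logp_duration[i, d]` holds $\log p_i(d+1)$):

$$
\alpha_t(i) = \mathbb{1}[t < D_{\max}]\,\pi_i\,p_i(t+1)\prod_{s=0}^{t} b_i(y_s)
\;+\; \sum_{d=0}^{\min(t,\,D_{\max})-1}\;\sum_{j \neq i} \alpha_{t-d-1}(j)\,A_{ji}\,p_i(d+1)\prod_{s=t-d}^{t} b_i(y_s)
$$

$$
p(y_0, \dots, y_{T-1}) = \sum_{i=1}^{N} \alpha_{T-1}(i)
$$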
I have serious problems with the run time: it is very slow even for small maximum durations. My question is: could I replace JAX with PyTensor to improve the performance of the algorithm?
Or do you have another recommendation to improve the performance?
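One direction that may help, given here as a sketch under stated assumptions rather than something tested on the real model: JAX unrolls plain Python `for` loops and `if` statements while tracing, so the loops above build a graph whose size grows roughly with T × N × D_max, and compile time grows with it. The sketch below computes the same recursion inside `jax.lax.scan`, carries the last D_max rows of alpha as a buffer, and takes segment emission sums from a cumulative sum instead of dynamic slices. The function name `hsmm_logp_scan`, its argument layout (a precomputed `logp_emission` of shape (T, N) and `logp_duration` of shape (N, D_max) whose column d holds log P(duration = d + 1)), and the `NEG` constant used in place of -inf to keep gradients finite are all hypothetical choices for illustration; self-transitions are excluded everywhere, as in the initial-phase loops.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp

# Hypothetical large finite stand-in for -inf: -inf inside logsumexp can give NaN gradients
NEG = -1e30


def hsmm_logp_scan(logp_emission, logp_initial_state, logp_transition, logp_duration):
    """Sketch: explicit-duration HMM forward algorithm with jax.lax.scan.

    logp_emission:      (T, N) log p(y_t | state i)
    logp_initial_state: (N,)   log p(first state = i)
    logp_transition:    (N, N) log p(j -> i) stored at [j, i]; diagonal is ignored
    logp_duration:      (N, D_max) column d = log p(duration = d + 1 | state i)
    """
    dtype = jnp.result_type(logp_emission, logp_initial_state, logp_transition, logp_duration)
    T, N = logp_emission.shape
    D_max = logp_duration.shape[1]

    # No self-transitions: a new segment must switch state
    trans = jnp.where(jnp.eye(N, dtype=bool), NEG, logp_transition)

    # C[t] = sum of log-emissions over steps 0..t-1, so steps a..b sum to C[b+1] - C[a]
    C = jnp.concatenate([jnp.zeros((1, N), dtype), jnp.cumsum(logp_emission, axis=0)])
    t_idx = jnp.arange(T)[:, None]      # (T, 1)
    d_idx = jnp.arange(D_max)[None, :]  # (1, D_max)
    start = t_idx - d_idx               # first step of a length-(d+1) segment ending at t
    seg = C[t_idx + 1] - C[jnp.maximum(start, 0)]        # (T, D_max, N)
    seg = jnp.where((start >= 0)[:, :, None], seg, NEG)  # segment would start before t = 0

    # First segment covers steps 0..t; only possible while t + 1 <= D_max
    t_all = jnp.arange(T)
    init_dur = logp_duration[:, jnp.minimum(t_all, D_max - 1)].T  # (T, N)
    init = jnp.where((t_all < D_max)[:, None], logp_initial_state + init_dur + C[1:], NEG)

    dur = logp_duration.T  # (D_max, N), row d = duration d + 1

    def step(buf, inputs):
        # buf[d] holds log_alpha[t - 1 - d]; rows before the series start are NEG
        init_t, seg_t = inputs  # (N,), (D_max, N)
        # Previous segment in state j ended at t-1-d, then j -> i: logsumexp over j
        enter = logsumexp(buf[:, :, None] + trans[None, :, :], axis=1)  # (D_max, N)
        rec = enter + dur + seg_t
        alpha_t = logsumexp(jnp.concatenate([init_t[None], rec]), axis=0)  # (N,)
        buf = jnp.concatenate([alpha_t[None], buf[:-1]])  # shift buffer by one step
        return buf, alpha_t

    buf0 = jnp.full((D_max, N), NEG, dtype)
    _, log_alpha = jax.lax.scan(step, buf0, (init, seg))
    return logsumexp(log_alpha[-1])


# Toy usage with made-up numbers: 2 states, D_max = 2
logp_e = jnp.log(jnp.array([[0.2, 0.5], [0.6, 0.1], [0.3, 0.3]]))
logp_pi = jnp.log(jnp.array([0.3, 0.7]))
logp_A = jnp.zeros((2, 2))               # off-diagonal log(1); diagonal is ignored
logp_D = jnp.log(jnp.full((2, 2), 0.5))  # P(duration 1) = P(duration 2) = 0.5
value = jax.jit(hsmm_logp_scan)(logp_e, logp_pi, logp_A, logp_D)
grad_e = jax.grad(hsmm_logp_scan)(logp_e, logp_pi, logp_A, logp_D)
print(value, grad_e)
```

Because the scan body is traced only once, trace and compile time should no longer grow with T, and the per-step work is O(D_max · N²). A function like this could presumably be wrapped for PyMC in the same way as in the gallery example. One trade-off of the cumulative-sum trick is some loss of precision for very long series, since it subtracts large partial sums.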
I would appreciate any help!