Semi-Markov chains with PyMC

Hi all,

I am currently using PyMC to fit an explicit-duration hidden Markov model (a hidden semi-Markov chain). This example helped me a lot when I started: "How to wrap a JAX function for use in PyMC" (PyMC example gallery).
I used it as the base code for developing the semi-Markov chain. However, the new likelihood has some specific difficulties: dynamic slices, several nested “for” loops, and conditionals. These lead to very long run times, whether I use the wrapped JAX function or try to vectorize it. Here is the code that calculates the likelihood using the forward algorithm:

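import jax.numpy as jnp
import jax.scipy as jsp
import numpy as np
from jax.scipy.special import logsumexp


def hmm_logp(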
    emission_observed,
    mu,
    sigma,
    logp_initial_state,
    logp_transition,
    lamda,
):

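    """Compute the marginal log-likelihood of a single HMM process."""
    T = emission_observed.shape[0]  # number of time steps
    N = logp_initial_state.shape[0]  # number of hidden states
    D_max = 2  # maximum state duration
    log_alpha = jnp.full((T, N), -np.inf)  # forward log-probabilities, filled in below
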
    # Compute log-likelihood of observed emissions for each (step x possible hidden state)
    logp_emission = jsp.stats.multivariate_normal.logpdf(
        emission_observed[:,None],
        mu,
        sigma,
    )

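    # Log-probability of each duration 1..D_max per state, shape (N, D_max);
    # column d corresponds to a duration of d + 1 steps
    logp_duration = jsp.stats.poisson.logpmf(jnp.arange(1, D_max + 1), lamda[:, None])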

    # Compute initial alpha values
    
    for t in range(min(D_max,T)):
      for i in range(N):
        if t == 0:
          log_alpha = log_alpha.at[t, i].set(logp_initial_state[i] + logp_duration[i, t] + logp_emission[t, i])
        elif t == 1:
          init_term = logp_initial_state[i] + logp_duration[i, t] + jnp.sum(logp_emission[:t+1, i])
          recursive_term = logsumexp(jnp.array([log_alpha[t-1,j]+logp_transition[j,i]+logp_duration[i,t-1]+logp_emission[t,i] for j in range(N) if j != i]))
          log_alpha = log_alpha.at[t, i].set(logsumexp(jnp.array([init_term, recursive_term])))
        else:
          init_term = logp_initial_state[i] + logp_duration[i, t] + jnp.sum(logp_emission[:t+1, i])
          recursive_term = logsumexp(jnp.array([log_alpha[t-1-d,j] + logp_transition[j,i] + logp_duration[i,d] + jnp.sum(logp_emission[t-d:t+1,i]) for d in range(t-1) for j in range(N) if j != i]))
          log_alpha = log_alpha.at[t, i].set(logsumexp(jnp.array([init_term, recursive_term])))

    # Forward pass to propagate the probabilities through time
    for t in range(D_max, T):
      for i in range(N):
        log_alpha = log_alpha.at[t,i].set(logsumexp(jnp.array([log_alpha[t-1-d,j] + logp_transition[j,i] + logp_duration[i,d] + jnp.sum(logp_emission[t-d:t+1,i]) for d in range(D_max) for j in range(N)])))

    # Compute the log-likelihood by summing over the final alpha values

    log_likelihood = jsp.special.logsumexp(log_alpha[-1,:])

    return log_likelihood
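
For reference, Python-level loops like the ones above are unrolled when JAX traces the function, so the traced graph grows with T * N * D_max. Below is a minimal sketch of how the main recursion (the `for t in range(D_max, T)` part) might be restructured. It is an illustration under assumptions, not a tested replacement: `segment_logp` and `forward_scan` are hypothetical helper names, `log_alpha_init` is assumed to hold the first D_max rows of `log_alpha` computed as in the initialization loop, and `logp_duration` is assumed to have shape (N, D_max). The dynamic slices are replaced by differences of a cumulative sum, and the time loop by `jax.lax.scan`.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def segment_logp(logp_emission, D_max):
    """seg[t, d, i] = sum(logp_emission[t - d : t + 1, i]); -inf where t - d < 0."""
    T, N = logp_emission.shape
    # csum[k] holds the sum of the first k rows, so a window sum is a difference
    csum = jnp.concatenate(
        [jnp.zeros((1, N), dtype=logp_emission.dtype),
         jnp.cumsum(logp_emission, axis=0)]
    )
    t = jnp.arange(T)[:, None]       # (T, 1)
    d = jnp.arange(D_max)[None, :]   # (1, D_max)
    start = t - d                    # first index of each window, (T, D_max)
    seg = csum[t + 1] - csum[jnp.clip(start, 0, T)]  # (T, D_max, N)
    return jnp.where((start >= 0)[:, :, None], seg, -jnp.inf)


def forward_scan(log_alpha_init, logp_transition, logp_duration, seg):
    """Run the recursion for t >= D_max, given the first D_max rows of log_alpha."""
    D_max = logp_duration.shape[1]

    def step(window, seg_t):
        # window[d, j] = log_alpha[t - 1 - d, j]; seg_t[d, i] = seg[t, d, i]
        # Sum over the previous state j for every lag d -> (D_max, N)
        inflow = logsumexp(
            window[:, :, None] + logp_transition[None, :, :], axis=1
        )
        # Add duration and emission terms, then sum over the lag d -> (N,)
        new_row = logsumexp(inflow + logp_duration.T + seg_t, axis=0)
        # Slide the window: newest row first, oldest row dropped
        window = jnp.concatenate([new_row[None, :], window[:-1]])
        return window, new_row

    window0 = log_alpha_init[::-1]  # most recent row first
    _, rows = jax.lax.scan(step, window0, seg[D_max:])
    return jnp.concatenate([log_alpha_init, rows])
```

This mirrors the main loop as written, which sums over all previous states j including j == i. If self-transitions should be excluded, one option would be to set the diagonal of `logp_transition` to -inf before calling it. The final log-likelihood would then be `logsumexp` of the last row, as in the original function.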

I have serious problems with the run time: it takes very long even for small values of the maximum duration. Could I improve the performance of the algorithm by replacing JAX with PyTensor?
Or do you have another recommendation to improve the performance?
I would appreciate any help!