Hidden Semi-Markov Model

Hello!
I am trying to estimate the parameters of a hidden semi-Markov model (HSMM) using PyMC. When I started with a hidden Markov model (HMM), I followed the guide available at this link: How to wrap a JAX function for use in PyMC — PyMC example gallery

Now I want to use the same methodology to estimate the HSMM's parameters: the duration distribution, the emission distribution, and the transition matrix.
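The forward recursion the code below implements, written in probability space (the code works entirely in log-space), is

$$\alpha_t(i) = \underbrace{\pi_i \, p_i(t) \prod_{s=0}^{t} b_i(y_s)}_{\text{only for } t < D_{\max}} \;+\; \sum_{d=0}^{D_{\max}-1} \sum_{j \neq i} \alpha_{t-1-d}(j)\, a_{ji}\, p_i(d) \prod_{s=t-d}^{t} b_i(y_s),$$

where $\pi_i$ are the initial-state probabilities, $a_{ji}$ the transition probabilities, $p_i(d)$ the duration pmf, and $b_i(y_s)$ the emission density. The marginal likelihood is then $\sum_i \alpha_{T-1}(i)$. In code: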

import jax
import jax.numpy as jnp
import jax.scipy as jsp
import numpy as np
from jax.scipy.special import logsumexp


def hmm_logp(
    emission_observed,
    mu,
    sigma,
    logp_initial_state,
    logp_transition,
    lamda,
):
    """Compute the marginal log-likelihood of a single HSMM process."""
    T = emission_observed.shape[0]
    N = logp_initial_state.shape[0]
    D_max = 80
    log_alpha = jnp.full((T, N), -np.inf)
    # Compute log-likelihood of observed emissions for each (step x possible hidden state)
    logp_emission = jsp.stats.multivariate_normal.logpdf(
        emission_observed[:,None],
        mu,
        sigma,
    )

    # Duration log-pmf for each (state x duration), shape (N, D_max)
    logp_duration = jsp.stats.poisson.logpmf(jnp.arange(D_max), lamda[:, None])
    # logp_duration = jnp.log(DiscreteWeibull(jnp.arange(1, D_max), lamda[:, None], beta[:, None]))

    # t = 0: weight each state by its initial probability and first emission
    log_alpha = log_alpha.at[0].set(logp_initial_state + logp_emission[0])

    def u_t(i, d, t, B):
        """Sum of emission log-probabilities B[t - d .. t, i] for a run of state i ending at t."""

        def scan_fn(sum_log_prob, offset):
            # Only add terms whose offset lies within the duration d
            sum_log_prob = jnp.where(offset <= d, sum_log_prob + B[t - offset, i], sum_log_prob)
            return sum_log_prob, None

        # Initialize the running sum to zero and scan over all possible offsets
        sum_log_prob, _ = jax.lax.scan(scan_fn, 0.0, jnp.arange(D_max))
        return sum_log_prob

    @jax.jit
    def compute_u_t(i, d, t, B):
        return u_t(i, d, t, B)
    
    def forward_step_init(carry, inputs):
        """Forward step for t < D_max: state i may still be in its initial run."""
        log_alpha_prev, t = carry
        (logp_emission_t,) = inputs  # unused; emissions are read via logp_emission
        log_alpha = log_alpha_prev

        for i in range(N):
            # State i has been occupied since t = 0 (initial run of length t)
            init_term = logp_initial_state[i] + logp_duration[i, t] + compute_u_t(i, t, t, logp_emission)
            # State i was entered at time t - d after leaving some state j
            recursive_term = logsumexp(jnp.array([
                log_alpha_prev[t - 1 - d, j] + logp_transition[j, i]
                + logp_duration[i, d] + compute_u_t(i, d, t, logp_emission)
                for d in range(D_max) for j in range(N) if j != i
            ]))
            log_alpha = log_alpha.at[t, i].set(logsumexp(jnp.array([init_term, recursive_term])))

        return (log_alpha, t + 1), None
    
    # Initialization phase: t = 1 .. D_max - 1
    inputs = (logp_emission[1:D_max],)
    (log_alpha, _), _ = jax.lax.scan(forward_step_init, (log_alpha, 1), inputs)

    def forward_step(carry, inputs):
        """Forward step for t >= D_max: only entries from another state are possible."""
        log_alpha_prev, t = carry
        (logp_emission_t,) = inputs  # unused; emissions are read via logp_emission
        log_alpha = log_alpha_prev

        for i in range(N):
            log_alpha = log_alpha.at[t, i].set(logsumexp(jnp.array([
                log_alpha_prev[t - 1 - d, j] + logp_transition[j, i]
                + logp_duration[i, d] + compute_u_t(i, d, t, logp_emission)
                for d in range(D_max) for j in range(N) if j != i
            ])))

        return (log_alpha, t + 1), None
    
    # Main phase: t = D_max .. T - 1
    inputs = (logp_emission[D_max:],)
    (log_alpha, _), _ = jax.lax.scan(forward_step, (log_alpha, D_max), inputs)

    # Marginal log-likelihood: sum over the states at the last time step
    last_alpha = log_alpha[-1, :]
    log_likelihood = logsumexp(last_alpha)

    return log_likelihood

I use this log-likelihood as input to the Bayesian model. The problem is that this computation is expensive: the recursion costs roughly O(T · N² · D_max²), because every term of the comprehensions over d and j re-runs a length-D_max scan inside compute_u_t, and those comprehensions are unrolled into the JAX graph. So the learning process takes a very long time when the number of observations and the maximum duration are large.
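For reference, the hmm_logp_op that appears in the model further down is just this JAX function wrapped as a PyTensor Op, following the linked guide. A minimal sketch of the wrapper, assuming the pattern from that guide (no gradient Op, so only gradient-free samplers such as Metropolis can use it):

import numpy as np
import pytensor.tensor as pt
from pytensor.graph import Apply, Op

# JIT-compile the JAX log-likelihood once
jitted_hmm_logp = jax.jit(hmm_logp)

class HMMLogpOp(Op):
    def make_node(self, *inputs):
        # Convert every input to a PyTensor variable; the output is a scalar
        inputs = [pt.as_tensor_variable(x) for x in inputs]
        outputs = [pt.dscalar()]
        return Apply(self, inputs, outputs)

    def perform(self, node, inputs, outputs):
        # Evaluate the jitted JAX function on the concrete NumPy inputs
        result = jitted_hmm_logp(*inputs)
        outputs[0][0] = np.asarray(result, dtype=node.outputs[0].dtype)

hmm_logp_op = HMMLogpOp()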

I am turning to anyone who has studied this type of model to see if you know a way around the time problem. Even with Metropolis, which I chose to avoid HMC's derivative calculations, the algorithm takes 8 hours. Or am I doing something wrong?
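One partial optimization I have been considering: the run-length sums computed by compute_u_t could be precomputed once with a cumulative sum, so that each u_t(i, d) becomes a difference of two table entries instead of a fresh length-D_max scan. A sketch of that replacement (it removes one factor of D_max, but the unrolled loops over d and j remain):

# Pad with a zero row so that cum_B[t] holds the sum of logp_emission[0 .. t-1]
cum_B = jnp.concatenate([jnp.zeros((1, N)), jnp.cumsum(logp_emission, axis=0)])

def u_t_fast(i, d, t):
    # sum_{s = t-d}^{t} logp_emission[s, i], valid whenever t - d >= 0
    return cum_B[t + 1, i] - cum_B[t - d, i]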

I am grateful for any help.

@cluhmann Hello Christian. Do you have any idea about this? I would be grateful if you could help me.

@jessegrabowski might have some suggestions?

That HMM example is an excuse to show JAX interoperability; don't get too hung up on having to write JAX code. You can do it with PyTensor just fine.
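For instance, a plain-HMM forward pass is just a scan on the PyTensor side. A rough sketch of the pattern (ordinary HMM, not your semi-Markov recursion; logp_emission, logp_initial_state, and logp_transition are assumed to be log-space tensors):

import pytensor
import pytensor.tensor as pt

def step(logp_emission_t, log_alpha_prev, logp_transition):
    # log_alpha_prev: (N,), logp_transition: (N, N), everything in log-space
    return pt.logsumexp(log_alpha_prev[:, None] + logp_transition, axis=0) + logp_emission_t

log_alpha_seq, _ = pytensor.scan(
    step,
    sequences=[logp_emission[1:]],
    outputs_info=[logp_initial_state + logp_emission[0]],
    non_sequences=[logp_transition],
)
log_likelihood = pt.logsumexp(log_alpha_seq[-1])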

Anyway, more importantly, you haven't shown us your priors or given details about the observed data, so we can't tell if there's something obvious.

Thank you!!! @ricardoV94
Here is the model with the priors:

import numpy as np
import pymc as pm
import pytensor.tensor as pt

n_states = 2
n_time_steps = states.shape[1] - 1
D_max = 25

with pm.Model() as model:
    # Priors for mu and sigma
    mu = pm.Uniform("mu", lower=[-4, 2], upper=[-2, 4], shape=n_states, initval=[-2.5, 3.8])
    sigma = pm.HalfNormal("sigma", sigma=1, shape=n_states, initval=[1, 1])

    # Fixed transition matrix: an HSMM excludes self-transitions,
    # so with two states the matrix is fully determined
    p_transition = np.array([[0, 1], [1, 0]])
    logp_transition = pt.log(p_transition)

    # Prior for the initial state
    p_initial_state = pm.Dirichlet("init_probs", a=np.array([500, 1]), initval=[0.99, 0.01])
    logp_initial_state = pt.log(p_initial_state)

    # Prior for the duration distribution (Poisson rate per state)
    lamda = pm.Uniform("lambda", lower=[10, 1], upper=[25, 10], shape=n_states, initval=[15, 5])

    #model.debug()

    # HSMM log-likelihood contributed through a Potential
    loglike = pm.Potential(
        "hmm_loglike",
        hmm_logp_op(
            observations,
            mu,
            sigma,
            logp_initial_state,
            logp_transition,
            lamda,
        ),
    )
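
I then sample with Metropolis, roughly like this (Metropolis as mentioned above, 3000 draws as in my Colab runs below; NUTS would additionally need a gradient Op for hmm_logp_op):

with model:
    # Gradient-free sampling; hmm_logp_op has no gradient implementation
    idata = pm.sample(3000, step=pm.Metropolis())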

And here is the synthetic data:

from scipy.stats import multinomial, norm, poisson

a = np.array([[0, 1], [1, 0]])
alpha = [20, 5]
mu = [-3, 3]
std = [1, 1]  # norm.rvs takes the scale (standard deviation), not the variance
pi = [0.9, 0.1]
sequences = 200

def EDHMM_sequence(p_init, p_transition, alpha_d, mu_b, std_b, T):
    """Simulate one explicit-duration HMM sequence of length T."""
    X, Y, D = [], [], []
    # Draw the initial state and its duration
    x = list(multinomial.rvs(1, p_init)).index(1)
    d = poisson.rvs(alpha_d[x])
    for t in range(T):
        # Emit an observation from the current state
        y = norm.rvs(mu_b[x], std_b[x])
        X.append(x)
        D.append(d)
        Y.append(y)
        if d > 1:
            d = d - 1
        else:
            # Duration exhausted: jump to a new state and draw a new duration
            x = list(multinomial.rvs(1, p_transition[x])).index(1)
            d = poisson.rvs(alpha_d[x])
    return X, Y, D

observations = []
states = []
durations = []
for i in range(sequences):
    x, y, d = EDHMM_sequence(pi, a, alpha, mu, std, 169)
    states.append(x)
    observations.append(y)
    durations.append(d)

observations = np.array(observations)
durations = np.array(durations)
states = np.array(states)

@ricardoV94 Do you see any possible error?

This morning I tried again in Google Colab. With 3000 samples the run time increases considerably. I think the main resource consumption is in the likelihood computation, but it takes so long that it becomes unfeasible for my purposes.