How to marginalize a Markov chain with Categorical?


#1

I have built an HMM following this: https://github.com/hstrey/Hidden-Markov-Models-pymc3.
However, it is too slow for my work, and the suggested fix is marginalization. As a first step I built a plain (non-marginalized) model for an observed Markov chain, and it works:

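import numpy as np
import pymc3 as pm
import theano.tensor as tt
import theano.tensor.slinalg as sla

# simulation of data. Two-state model for simplicity.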
N_seqs = 100
N_labels = 3
N_states = 2

# transition probability
P=np.array([[0.8,0.2],[0.4,0.6]])
# emission probability
Pe = np.array([
    [0.8,0.1,0.1],
    [0.3,0.4,0.3]
]) 
N_labels = Pe.shape[1]

AA = np.eye(N_states) - P + np.ones(shape=(N_states,N_states))
PA = np.linalg.solve(AA.T,np.ones(shape=(N_states)))

state_seq = [np.random.choice(N_states, p=PA)]
for i in range(1, N_seqs):
    state_seq += [np.random.choice(N_states, p=P[state_seq[-1]])]

label_seq = [np.random.choice(N_labels, p=Pe[state_seq[i]]) for i in range(N_seqs)]

print(state_seq[:20])
print(label_seq[:20])
with pm.Model() as model:
    P = pm.Dirichlet('P_transition', a=np.ones((N_states,N_states)), shape=(N_states,N_states))
    
    # stationary distribution: solves (I - P + 1)^T pi = 1
    AA = tt.eye(N_states) - P + tt.ones(shape=(N_states,N_states))
    PA = pm.Deterministic('P_start_transition',sla.solve(AA.T,tt.ones(shape=(N_states))))
    
    states = [pm.Categorical("state_0", p=PA, observed=state_seq[0])]
    for i in range(1, len(state_seq)):
        states += [pm.Categorical(
                        "state_%d"%i, 
                        p=P[states[-1]], 
                        observed=state_seq[i]
                    )]
    
    trace = pm.sample(100)  # to time this in Jupyter, put %%time on the first line of the cell
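Because every state is observed in this model, the loop of `Categorical` variables just adds one `log P[s_{t-1}, s_t]` term per step, which builds a large graph for long sequences. A NumPy sketch of the same log-likelihood as a single vectorized expression (the function names are hypothetical; the stationary initial distribution matches the simulation code):

```python
import numpy as np


def stationary_distribution(P):
    # Solve (I - P + 1)^T pi = 1, the same trick used in the simulation code.
    n = P.shape[0]
    A = np.eye(n) - P + np.ones((n, n))
    return np.linalg.solve(A.T, np.ones(n))


def markov_chain_loglik(state_seq, P):
    # Fully observed chain: log pi[s_0] + sum_t log P[s_{t-1}, s_t], no Python loop.
    s = np.asarray(state_seq)
    pi = stationary_distribution(P)
    return np.log(pi[s[0]]) + np.sum(np.log(P[s[:-1], s[1:]]))


P = np.array([[0.8, 0.2], [0.4, 0.6]])
print(stationary_distribution(P))  # approximately [2/3, 1/3]
print(markov_chain_loglik([0, 0, 1, 1, 0], P))
```

Inside a PyMC3 model, the same expression written with `theano.tensor` operations on the Dirichlet `P` could be added as one `pm.Potential` term instead of one `Categorical` per step; that translation is an untested assumption.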

Now I want to marginalize out all the latent states of the Markov chain. I tried:

with pm.Model() as model:
    comp_dists = []
    for i in range(N_states):
        P = pm.Dirichlet('P_transition_%d'%i, a=np.ones((N_states,)), shape=(N_states,))
        comp_dists += [pm.Categorical("transition_comp_%d"%i, p=P)]
    
    AA = tt.dmatrix('AA')
    AA = tt.eye(N_states) - P + tt.ones(shape=(N_states,N_states))
    PA = pm.Deterministic('P_start_transition',sla.solve(AA.T,tt.ones(shape=(N_states))))

    states = [pm.Categorical("state_0", p=PA, observed=state_seq[0])]
    states += [pm.Mixture(
        "state_1",
        w=PA, 
        comp_dists=comp_dists,
        observed=state_seq[1]
    )]
    for i in range(2, len(state_seq)):
        states += [pm.Mixture(
                        "state_%d"%i,
                        w=P[states[-2]], 
                        comp_dists=comp_dists,
                        observed=state_seq[i]
                    )]
    
    trace = pm.sample(100)

Of course, it did not work. My questions are:

  1. Does anyone know how to marginalize them?
  2. I also know another way is to write the logp() of a new HMM distribution class. Can this approach use the Viterbi algorithm? And how does its speed compare to the approach in question 1?
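On the Viterbi part of question 2: Viterbi answers a different question from marginalization. It finds the single most probable state path for fixed parameters, using a max where the forward algorithm uses a log-sum-exp, so it would typically be run on parameter values after sampling rather than inside logp. A minimal log-space NumPy sketch (the function name is hypothetical; the matrices are the ones from the simulation code):

```python
import numpy as np


def viterbi(log_init, log_trans, log_emit):
    """Most probable state path.

    log_init: (N,) log initial probabilities
    log_trans: (N, N) log transition matrix, rows = from-state
    log_emit: (T, N) log-probability of observation t under state j
    """
    T, N = log_emit.shape
    delta = log_init + log_emit[0]          # best log-score of a path ending in each state
    backptr = np.zeros((T, N), dtype=int)
    for t in range(1, T):
        # scores[j, i]: best path ending in j at t-1, then moving j -> i
        scores = delta[:, None] + log_trans
        backptr[t] = np.argmax(scores, axis=0)
        delta = scores[backptr[t], np.arange(N)] + log_emit[t]
    # Trace the best path backwards from the best final state.
    path = np.zeros(T, dtype=int)
    path[-1] = np.argmax(delta)
    for t in range(T - 1, 0, -1):
        path[t - 1] = backptr[t, path[t]]
    return path, np.max(delta)


P = np.array([[0.8, 0.2], [0.4, 0.6]])
Pe = np.array([[0.8, 0.1, 0.1], [0.3, 0.4, 0.3]])
labels = [0, 0, 2, 1, 1, 0]
path, score = viterbi(np.log([2 / 3, 1 / 3]), np.log(P), np.log(Pe[:, labels].T))
print(path, score)
```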

Thank you!


#2

There is more information in the Stan manual (3.6, Semisupervised Estimation). I have not tried it myself, but it should work similarly in PyMC3.


#3

Hi,

you want to use the forward algorithm to marginalize out the latent states and obtain the model likelihood. You can also use the forward-backward algorithm to obtain a related quantity known as the ‘Complete Likelihood’, but it’s not necessary. I’ve successfully implemented a hierarchical nonhomogeneous HMM in both PyMC3 and Stan, ultimately choosing to stick with Stan for performance reasons.

Here’s some Python code for the forward algorithm to get you started:

import numpy as np
from scipy.special import logsumexp


def forward(initial_probabilities, transmat, emission_lattice):
    '''
    Forward algorithm: computes the log-likelihood of the observed data, given the model. Performs
    computations in log-space to avoid underflow issues. Computes and returns the full forward
    matrix and the final sum-of-all-paths probability.

    :initial_probabilities: array-like, sequence of initial state probabilities (log-space)
    :transmat: transition matrix (row-stochastic) of transition probabilities (log-space)
    :emission_lattice: matrix where the (i, j)th element is the log-probability of observation i being emitted by state j
    :return: float, log-probability (score) of the observed sequence under the model; array-like, full forward matrix
    '''
    A = np.asarray(transmat)
    log_init_p = np.asarray(initial_probabilities)
    B = np.asarray(emission_lattice)

    T, N = B.shape
    alpha = np.zeros((T, N))
    # initialize first row of the forward matrix
    alpha[0, :] = log_init_p + B[0, :]
    tmp = np.zeros(N)

    for t in range(1, T):
        for i in range(0, N):
            for j in range(0, N):
                # log-probability of being in state j at t-1 and then moving to state i
                tmp[j] = alpha[t - 1, j] + A[j, i]
            alpha[t, i] = logsumexp(tmp) + B[t, i]
    score = logsumexp(alpha[-1])
    return score, alpha
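The two inner loops over states in `forward` can be collapsed into a single broadcasted log-sum-exp per time step, which is also the shape a `theano.scan` step function would take. A minimal NumPy/SciPy sketch, assuming the same log-space inputs (`forward_vectorized` is a hypothetical name):

```python
import numpy as np
from scipy.special import logsumexp


def forward_vectorized(log_init, log_trans, log_emit):
    """Log-space forward algorithm with the loops over states vectorized.

    log_init: (N,), log_trans: (N, N) with rows = from-state, log_emit: (T, N).
    Returns (log-likelihood, forward matrix alpha of shape (T, N)).
    """
    T, N = log_emit.shape
    alpha = np.empty((T, N))
    alpha[0] = log_init + log_emit[0]
    for t in range(1, T):
        # alpha[t, i] = logsumexp_j(alpha[t-1, j] + A[j, i]) + B[t, i]
        alpha[t] = logsumexp(alpha[t - 1][:, None] + log_trans, axis=0) + log_emit[t]
    return logsumexp(alpha[-1]), alpha


P = np.array([[0.8, 0.2], [0.4, 0.6]])
Pe = np.array([[0.8, 0.1, 0.1], [0.3, 0.4, 0.3]])
labels = [0, 2, 1, 1]
score, alpha = forward_vectorized(np.log([2 / 3, 1 / 3]), np.log(P), np.log(Pe[:, labels].T))
print(score)
```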

To implement this as part of your model, you will have to use theano.scan(). The forward pass can be implemented in just a few lines of code (I apologise, but I can’t share my implementation :frowning:).
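`theano.scan` calls its step function with the current sequence slice first, then the previous output, then any non-sequences. The NumPy sketch below mirrors that structure with a plain Python loop standing in for scan; it is not the poster's implementation, and the helper names are hypothetical.

```python
import numpy as np
from scipy.special import logsumexp


def forward_step(log_emit_t, alpha_prev, log_trans):
    # Argument order mirrors theano.scan: sequence slice, previous output, non-sequence.
    return logsumexp(alpha_prev[:, None] + log_trans, axis=0) + log_emit_t


def scan_like(fn, sequence, init, non_sequence):
    # Plain-Python stand-in for scan: each output is fed back in as the next "previous".
    outputs, prev = [], init
    for x in sequence:
        prev = fn(x, prev, non_sequence)
        outputs.append(prev)
    return outputs


def forward_loglik(log_init, log_trans, log_emit):
    alpha0 = log_init + log_emit[0]
    alphas = scan_like(forward_step, log_emit[1:], alpha0, log_trans)
    final = alphas[-1] if alphas else alpha0
    return logsumexp(final)
```

With Theano, `scan_like` would become roughly `theano.scan(fn=forward_step, sequences=log_emit[1:], outputs_info=alpha0, non_sequences=log_trans)`, taking the first element of the returned pair, with `forward_step` rewritten in tensor ops and a Theano-compatible log-sum-exp (for example `pymc3.math.logsumexp`, minding its `keepdims` default). Treat that as an untested outline.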

A few notes which are very important:

  • All computations must be done in log-space to avoid numerical underflow, which will not only mess up your log-likelihood but result in bizarro gradient issues.
  • scan() is pretty slow and finicky, so be patient with it! I recommend you code up some toy examples of the forward algorithm and compare them against a pure Python version to convince yourself you’ve got it right.
  • Expect to encounter nasty model identification problems and pathologically multimodal posteriors unless you force constraints upon your problem and/or use very informative priors (literally prejudiced priors!).
  • You may recover the complete model posterior marginals for the latent states by sampling from the overall posterior, computing the latent state distributions from each sample with the forward-backward algorithm, doing that many times, and taking the mean of the results. I discussed with Richard McElreath at StanCon Helsinki this year whether this makes sense, and the consensus is that this is how one goes about it for marginalized models.
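The last bullet relies on the forward-backward algorithm to get per-time-step state probabilities for one set of parameters (one posterior draw); averaging those over draws gives the marginals described. A log-space NumPy sketch with hypothetical names, plus a brute-force enumeration of the kind the second bullet recommends for checking:

```python
import itertools
import numpy as np
from scipy.special import logsumexp


def forward_backward(log_init, log_trans, log_emit):
    """Return (log-likelihood, gamma), gamma[t, i] = p(state_t = i | all observations)."""
    T, N = log_emit.shape
    alpha = np.empty((T, N))
    beta = np.zeros((T, N))                  # log(1) at the final step
    alpha[0] = log_init + log_emit[0]
    for t in range(1, T):
        alpha[t] = logsumexp(alpha[t - 1][:, None] + log_trans, axis=0) + log_emit[t]
    for t in range(T - 2, -1, -1):
        # beta[t, j] = logsumexp_i(A[j, i] + B[t+1, i] + beta[t+1, i])
        beta[t] = logsumexp(log_trans + (log_emit[t + 1] + beta[t + 1])[None, :], axis=1)
    loglik = logsumexp(alpha[-1])
    return loglik, np.exp(alpha + beta - loglik)


def brute_force(log_init, log_trans, log_emit):
    # Enumerate every state path (only feasible for tiny T) as a reference check.
    T, N = log_emit.shape
    log_joint = {}
    for path in itertools.product(range(N), repeat=T):
        lp = log_init[path[0]] + sum(log_trans[path[t - 1], path[t]] for t in range(1, T))
        log_joint[path] = lp + sum(log_emit[t, path[t]] for t in range(T))
    loglik = logsumexp(list(log_joint.values()))
    gamma = np.zeros((T, N))
    for path, lp in log_joint.items():
        for t, s in enumerate(path):
            gamma[t, s] += np.exp(lp - loglik)
    return loglik, gamma


P = np.array([[0.8, 0.2], [0.4, 0.6]])
Pe = np.array([[0.8, 0.1, 0.1], [0.3, 0.4, 0.3]])
labels = [0, 0, 2, 1, 1, 0]
args = (np.log([2 / 3, 1 / 3]), np.log(P), np.log(Pe[:, labels].T))
ll_fb, gamma_fb = forward_backward(*args)
ll_bf, gamma_bf = brute_force(*args)
print(np.isclose(ll_fb, ll_bf), np.allclose(gamma_fb, gamma_bf))  # True True
```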

Good luck!


#4

These are some pretty good tips, thanks for sharing (it almost makes me forgive you for not sharing your implementation :joy:)! I can't like this enough.

I guess you also saw identification problems in your Stan model, right?


#5

That’s correct - identification problems are really just a feature of the HMMs themselves, not of the implementation per se.

Stan offers some nice tools to help with this like ordered vectors (perhaps there’s an equivalent in pymc3?), but in my particular industrial use-case it was a combination of strong emission priors and restrictions on the permitted transitions which made the model play nice enough to be useful.
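Stan's `ordered` type is built from an unconstrained vector y via x_1 = y_1 and x_k = x_{k-1} + exp(y_k), which forces x to be strictly increasing and so rules out relabelled copies of, for example, state-specific emission means. Recent PyMC3 3.x releases ship an analogous transform in `pm.distributions.transforms` (check your version). A library-free NumPy sketch of the mapping, its inverse and its log-Jacobian, with hypothetical names:

```python
import numpy as np


def ordered_forward(y):
    """Map an unconstrained vector y to a strictly increasing vector x (Stan-style ordered)."""
    y = np.asarray(y, dtype=float)
    x = np.empty_like(y)
    x[0] = y[0]
    x[1:] = y[0] + np.cumsum(np.exp(y[1:]))
    return x


def ordered_inverse(x):
    # y_1 = x_1, y_k = log(x_k - x_{k-1})
    x = np.asarray(x, dtype=float)
    return np.concatenate([x[:1], np.log(np.diff(x))])


def ordered_log_jacobian(y):
    # The Jacobian is triangular with diagonal (1, exp(y_2), ..., exp(y_K)), so log|J| = sum(y[1:]).
    return np.sum(np.asarray(y, dtype=float)[1:])


y = np.array([0.3, -1.2, 0.5, 2.0])
print(ordered_forward(y))
```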

You have to be careful, because it’s tempting to use ADVI to save time over the agonisingly slow HMC inference of long sequences of data, but ADVI will of course obscure multimodality, and you can’t be sure exactly how it is doing so (mode-seeking or mode-covering): it can do either, depending on the particular geometry of the problem.

With that said, I had some positive results playing with Stein Variational Gradient Descent in pymc3, but it was a bit prone to failing for inscrutable reasons, so I bookmarked it and moved on :expressionless:


#6

Yes, we also have that.

Thanks again for your feedback! I also find that the out-of-the-box VI we have (at least ADVI and FullRankADVI) does not handle mixture models well; my reasoning is that a VI approximation and fitting algorithm specific to this kind of model should be used.


#7

This is a very informative discussion. I am familiar with the forward-backward algorithm, but have never used Stan or PyMC3 before. It looks like the biggest challenge for me is speed: I have hundreds of thousands of sequences to train on. Another challenge is that this is just the first step toward my full model, which has two additional latent variables, U and S, shown in the following plate diagram.

It will increase the complexity of the modeling quickly.

My last question: what if I only want a point estimate (not the full Bayesian posterior)? What are my options? I tried the find_MAP() function on my simulated toy example without marginalization, and it did not do well.
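One option for a point estimate is to maximize the marginalized (forward-algorithm) log-likelihood directly, rather than running find_MAP on a model that still has explicit discrete states. A minimal SciPy sketch, assuming flat priors (so this is maximum likelihood rather than MAP), a freely estimated initial distribution, and a softmax parametrization of each probability row; all names are hypothetical. Given the identifiability issues discussed above, expect label switching and local optima, so try several `seed` values.

```python
import numpy as np
from scipy.optimize import minimize
from scipy.special import logsumexp, log_softmax


def unpack(theta, n_states, n_labels):
    # Unconstrained vector -> log initial, log transition and log emission probabilities.
    k = n_states
    log_init = log_softmax(theta[:k])
    log_trans = log_softmax(theta[k:k + k * k].reshape(k, k), axis=1)
    log_emit = log_softmax(theta[k + k * k:].reshape(k, n_labels), axis=1)
    return log_init, log_trans, log_emit


def neg_loglik(theta, label_seqs, n_states, n_labels):
    # Sum of forward-algorithm log-likelihoods over independent sequences, negated.
    log_init, log_trans, log_emit = unpack(theta, n_states, n_labels)
    total = 0.0
    for labels in label_seqs:
        lp = log_emit[:, labels].T          # (T, N) emission log-probabilities
        alpha = log_init + lp[0]
        for t in range(1, len(labels)):
            alpha = logsumexp(alpha[:, None] + log_trans, axis=0) + lp[t]
        total += logsumexp(alpha)
    return -total


def fit_hmm(label_seqs, n_states, n_labels, seed=0):
    rng = np.random.default_rng(seed)
    n_params = n_states + n_states * n_states + n_states * n_labels
    theta0 = rng.normal(scale=0.5, size=n_params)
    res = minimize(neg_loglik, theta0, args=(label_seqs, n_states, n_labels), method="L-BFGS-B")
    return unpack(res.x, n_states, n_labels), res


seqs = [[0, 0, 2, 1, 1, 0, 0, 1], [1, 2, 1, 0, 0, 0]]
(log_init, log_trans, log_emit), res = fit_hmm(seqs, n_states=2, n_labels=3)
print(np.exp(log_trans))
```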


#8

It looks like PyTorch's dynamic graph approach could give better speed for this kind of scan() operation. I will look into it to see the difference.


#9

FYI, I implemented the forward algorithm using theano.scan, notebook here: https://github.com/junpenglao/Planet_Sakaar_Data_Science/blob/master/PyMC3QnA/discourse_2230.ipynb

Observation:

  • unsupervised model is super unidentifiable
  • semi-supervised model seems to perform quite well
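In the semi-supervised setting (as in the Stan manual section mentioned earlier in the thread), sequences whose states are known contribute their complete-data log-likelihood, and the rest contribute the forward-algorithm marginal; both terms share the same parameters, which is what anchors the state labels. A NumPy sketch under those assumptions, with hypothetical names:

```python
import numpy as np
from scipy.special import logsumexp


def supervised_loglik(states, labels, log_init, log_trans, log_emit):
    # States observed: log p(s_0) + sum log p(s_t | s_{t-1}) + sum log p(y_t | s_t).
    s, y = np.asarray(states), np.asarray(labels)
    return log_init[s[0]] + log_trans[s[:-1], s[1:]].sum() + log_emit[s, y].sum()


def unsupervised_loglik(labels, log_init, log_trans, log_emit):
    # States latent: marginalize them out with the forward algorithm.
    lp = log_emit[:, labels].T
    alpha = log_init + lp[0]
    for t in range(1, len(labels)):
        alpha = logsumexp(alpha[:, None] + log_trans, axis=0) + lp[t]
    return logsumexp(alpha)


def semi_supervised_loglik(labelled, unlabelled, log_init, log_trans, log_emit):
    """labelled: list of (states, labels) pairs; unlabelled: list of label sequences."""
    sup = sum(supervised_loglik(s, y, log_init, log_trans, log_emit) for s, y in labelled)
    unsup = sum(unsupervised_loglik(y, log_init, log_trans, log_emit) for y in unlabelled)
    return sup + unsup


P = np.array([[0.8, 0.2], [0.4, 0.6]])
Pe = np.array([[0.8, 0.1, 0.1], [0.3, 0.4, 0.3]])
params = (np.log([2 / 3, 1 / 3]), np.log(P), np.log(Pe))
labelled = [([0, 0, 1, 1], [0, 0, 1, 2])]
unlabelled = [[0, 2, 1, 1], [1, 1, 0]]
print(semi_supervised_loglik(labelled, unlabelled, *params))
```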

@nmrobert, do you see similar results in Stan?


#10

@junpenglao Nicely done :slight_smile:

Sorry for the delayed response - yes, this is very consistent with my experiences. Even with very strong priors on the transition probabilities you tend to see label switching at best (leading to multimodal, poorly mixed chains), or complete model meltdown at worst.

The model I ended up using was also a semi-supervised one of sorts, where I fixed a few of the transition probabilities that were known from domain knowledge and constrained the others quite strongly; this was enough to make the model play nice.
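One way to encode "a few transition probabilities fixed from domain knowledge" is to fix those entries (including structural zeros for forbidden transitions) and spread each row's remaining probability mass over the free entries with a softmax. A NumPy sketch of that parametrization; the NaN-for-free encoding and the function name are assumptions for illustration:

```python
import numpy as np


def build_transmat(fixed, free_logits):
    """Row-stochastic transition matrix with some entries fixed from domain knowledge.

    fixed: (N, N) array holding the fixed probability of each known transition
           (0.0 for a forbidden one) and NaN for every free entry.
    free_logits: 1-D array, one unconstrained value per NaN entry, in row-major order.
    """
    fixed = np.asarray(fixed, dtype=float)
    free = np.isnan(fixed)
    P = np.where(free, 0.0, fixed)
    logits = np.full(fixed.shape, -np.inf)
    logits[free] = free_logits  # boolean-mask assignment fills entries in row-major order
    for i in range(P.shape[0]):
        remaining = 1.0 - P[i].sum()
        if remaining < -1e-12:
            raise ValueError(f"fixed entries in row {i} sum to more than 1")
        if free[i].any():
            # softmax over the free entries only; fixed entries get weight exp(-inf) = 0
            w = np.exp(logits[i] - logits[i][free[i]].max())
            P[i] += remaining * w / w.sum()
        elif not np.isclose(remaining, 0.0):
            raise ValueError(f"row {i} is fully fixed but does not sum to 1")
    return P


fixed = np.array([
    [np.nan, np.nan, 0.0],     # transition 0 -> 2 forbidden
    [0.1, np.nan, np.nan],     # 1 -> 0 known to be 0.1
    [np.nan, np.nan, np.nan],  # row 2 entirely free
])
print(build_transmat(fixed, np.zeros(7)))
```

In a Bayesian model the free logits would be random variables, and tight priors on them are one way to "constrain the others quite strongly"; translating this to Theano ops is left as an assumption here.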