I have built an HMM following this repository: https://github.com/hstrey/Hidden-Markov-Models-pymc3.
However, it is too slow for my work. The suggested fix is to marginalize out the latent states. As a first step, I built an ordinary (non-marginalized) model of the Markov chain alone, with the states observed, and it works:
```python
import numpy as np
import pymc3 as pm
import theano.tensor as tt
import theano.tensor.slinalg as sla

# Simulate data. Two-state model for simplicity.
N_seqs = 100
N_states = 2
# Transition probability matrix
P = np.array([[0.8, 0.2], [0.4, 0.6]])
# Emission probability matrix (one row per state)
Pe = np.array([
    [0.8, 0.1, 0.1],
    [0.3, 0.4, 0.3]
])
N_labels = Pe.shape[1]
# Stationary distribution of P, used as the start distribution
AA = np.eye(N_states) - P + np.ones(shape=(N_states, N_states))
PA = np.linalg.solve(AA.T, np.ones(shape=(N_states)))
state_seq = [np.random.choice(N_states, p=PA)]
for i in range(1, N_seqs):
    state_seq += [np.random.choice(N_states, p=P[state_seq[-1]])]
label_seq = [np.random.choice(N_labels, p=Pe[state_seq[i]]) for i in range(N_seqs)]
print(state_seq[:20])
print(label_seq[:20])
```
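The start distribution `PA` above is the stationary distribution of `P`. The `solve(AA.T, ones)` trick works because a stationary row vector π satisfies π(I - P) = 0 and, since its entries sum to one, πJ = 1ᵀ, where J is the all-ones matrix; adding the two gives π(I - P + J) = 1ᵀ, i.e. AAᵀπᵀ = 1. A quick NumPy check with the simulation's transition matrix:

```python
import numpy as np

# Same two-state transition matrix as in the simulation
P = np.array([[0.8, 0.2], [0.4, 0.6]])
n = P.shape[0]

# Solve (I - P + J)^T pi = 1, where J is the all-ones matrix
AA = np.eye(n) - P + np.ones((n, n))
pi = np.linalg.solve(AA.T, np.ones(n))

# pi should be a probability vector that is left-invariant under P
print(pi)
print(np.allclose(pi @ P, pi))
```

For this P, balancing the flow between the two states (0.2 π₀ = 0.4 π₁) gives π = [2/3, 1/3].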
```python
with pm.Model() as model:
    # Note: this rebinds P from the NumPy array to the Dirichlet RV
    P = pm.Dirichlet('P_transition', a=np.ones((N_states, N_states)), shape=(N_states, N_states))
    AA = tt.eye(N_states) - P + tt.ones(shape=(N_states, N_states))
    PA = pm.Deterministic('P_start_transition', sla.solve(AA.T, tt.ones(shape=(N_states))))
    states = [pm.Categorical("state_0", p=PA, observed=state_seq[0])]
    # One observed Categorical per time step, conditioned on the previous state
    for i in range(1, len(state_seq)):
        states += [pm.Categorical(
            "state_%d" % i,
            p=P[states[-1]],
            observed=state_seq[i]
        )]
```

```python
%%time
with model:
    trace = pm.sample(100)
```
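For context on the speed problem: the loop above adds one `Categorical` per time step, so the model graph grows linearly with the sequence length, which is probably a large part of the cost. With the states observed, the chain's log-likelihood is just log π(s₀) + Σₜ log P[sₜ₋₁, sₜ], which fits in one vectorized expression. The NumPy sketch below shows that sum; `markov_chain_logp` is a hypothetical helper for illustration, not a PyMC3 function. The PyMC3 analogue would presumably be a single `pm.Categorical` with `p=P[state_seq[:-1]]` and `observed=state_seq[1:]`, but that is an untested assumption.

```python
import numpy as np

def markov_chain_logp(seq, P, pi):
    """Log-likelihood of an observed state sequence under a Markov chain.

    seq: 1-D sequence of integer states, P: (K, K) transition matrix whose
    rows sum to 1, pi: (K,) initial state distribution.
    """
    seq = np.asarray(seq)
    # log p(s_0) + sum over t of log P[s_{t-1}, s_t], with no Python loop
    return np.log(pi[seq[0]]) + np.log(P[seq[:-1], seq[1:]]).sum()

P = np.array([[0.8, 0.2], [0.4, 0.6]])
pi = np.array([2 / 3, 1 / 3])
seq = [0, 0, 1, 1, 0]
print(markov_chain_logp(seq, P, pi))
```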
Now I want to marginalize out all the latent states of the Markov chain. I tried:
```python
with pm.Model() as model:
    comp_dists = []
    for i in range(N_states):
        P = pm.Dirichlet('P_transition_%d' % i, a=np.ones((N_states,)), shape=(N_states,))
        comp_dists += [pm.Categorical("transition_comp_%d" % i, p=P)]
    AA = tt.eye(N_states) - P + tt.ones(shape=(N_states, N_states))
    PA = pm.Deterministic('P_start_transition', sla.solve(AA.T, tt.ones(shape=(N_states))))
    states = [pm.Categorical("state_0", p=PA, observed=state_seq[0])]
    states += [pm.Mixture(
        "state_1",
        w=PA,
        comp_dists=comp_dists,
        observed=state_seq[1]
    )]
    for i in range(2, len(state_seq)):
        states += [pm.Mixture(
            "state_%d" % i,
            w=P[states[-2]],
            comp_dists=comp_dists,
            observed=state_seq[i]
        )]
    trace = pm.sample(100)
```
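For reference, the standard way to marginalize the hidden chain of an HMM is the forward algorithm: instead of one discrete RV per step, it carries a vector α_t(k) = log p(y₀..y_t, s_t = k) and updates it with a log-sum-exp over the previous state, for a total cost of O(T·K²). The NumPy sketch below computes that marginal log-likelihood of the emitted labels, assuming the `P`, `Pe` and stationary start distribution from the simulation; `logsumexp`, `hmm_forward_logp` and `brute_force_logp` are hypothetical helper names. Inside a PyMC3 model the same recursion would presumably have to be written with Theano ops (for example `theano.scan`) in a custom `logp` or `pm.Potential`; that part is an assumption and is not shown.

```python
import itertools
import numpy as np

def logsumexp(a, axis=None):
    # Numerically stable log(sum(exp(a))) along an axis
    m = np.max(a, axis=axis, keepdims=True)
    out = m + np.log(np.sum(np.exp(a - m), axis=axis, keepdims=True))
    return np.squeeze(out, axis=axis)

def hmm_forward_logp(obs, P, Pe, pi):
    """Marginal log p(obs) of an HMM, summing over all hidden state paths."""
    obs = np.asarray(obs)
    logP, logPe = np.log(P), np.log(Pe)
    # alpha[k] = log p(obs[0..t], state_t = k)
    alpha = np.log(pi) + logPe[:, obs[0]]
    for y in obs[1:]:
        # alpha[j] + log P[j, k], summed over previous state j, then emit y
        alpha = logsumexp(alpha[:, None] + logP, axis=0) + logPe[:, y]
    return float(logsumexp(alpha))

def brute_force_logp(obs, P, Pe, pi):
    # Sum p(path, obs) over all K**T paths; only feasible for tiny T
    total = 0.0
    for path in itertools.product(range(len(pi)), repeat=len(obs)):
        p = pi[path[0]] * Pe[path[0], obs[0]]
        for t in range(1, len(obs)):
            p *= P[path[t - 1], path[t]] * Pe[path[t], obs[t]]
        total += p
    return float(np.log(total))

P = np.array([[0.8, 0.2], [0.4, 0.6]])
Pe = np.array([[0.8, 0.1, 0.1], [0.3, 0.4, 0.3]])
pi = np.array([2 / 3, 1 / 3])
obs = [0, 2, 1, 0, 0]
print(hmm_forward_logp(obs, P, Pe, pi), brute_force_logp(obs, P, Pe, pi))
```

The brute-force enumeration is only a sanity check; the forward recursion gives the same number without ever sampling or enumerating the discrete states.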
Of course, it did not work. My questions are:
1. Does anyone know how to marginalize out the latent states?
2. I also know that another option is to write a new HMM distribution class with its own `logp()`. Could that approach make use of the Viterbi algorithm? And how would its speed compare with the approach in question 1?
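On the Viterbi part of question 2: Viterbi answers a different question from marginalization. A `logp()` built on the forward recursion sums over state paths (which is what a sampler needs), whereas Viterbi replaces that sum with a max to recover the single most probable path, so it would typically be run after sampling, for example with posterior estimates of `P` and `Pe`, rather than inside `logp()`. Both recursions cost O(T·K²). A minimal NumPy sketch; the function name `viterbi` and the example sequence are illustrative assumptions:

```python
import numpy as np

def viterbi(obs, P, Pe, pi):
    """Most probable hidden state path of an HMM (max-product in log space).

    Returns (path, log p(path, obs)) for the best path.
    """
    obs = np.asarray(obs)
    logP, logPe = np.log(P), np.log(Pe)
    T, K = len(obs), len(pi)
    delta = np.log(pi) + logPe[:, obs[0]]    # best log-score ending in each state
    backptr = np.zeros((T, K), dtype=int)    # best predecessor per step and state
    for t in range(1, T):
        scores = delta[:, None] + logP        # scores[j, k]: from state j to state k
        backptr[t] = np.argmax(scores, axis=0)
        delta = scores.max(axis=0) + logPe[:, obs[t]]
    # Trace the best path backwards from the best final state
    path = [int(np.argmax(delta))]
    for t in range(T - 1, 0, -1):
        path.append(int(backptr[t, path[-1]]))
    return path[::-1], float(delta.max())

P = np.array([[0.8, 0.2], [0.4, 0.6]])
Pe = np.array([[0.8, 0.1, 0.1], [0.3, 0.4, 0.3]])
pi = np.array([2 / 3, 1 / 3])
obs = [0, 0, 1, 2, 1, 0]
best_path, best_logp = viterbi(obs, P, Pe, pi)
print(best_path, best_logp)
```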
Thank you!