Hi,
I am currently trying to implement an LDA model with two latent variables in PyMC3 using AEVB.
The model is used to cluster gene distributions for different types of cells and has two latent variables: transcriptional states and topics. Genes and cells resemble the words and documents, and the transcripts resemble the tokens in the original LDA model (Blei et al.).
The model specification is as follows:
Each cell has a multinomial distribution over topics (latent variable 1), each topic has a multinomial distribution over all transcriptional states (latent variable 2), and each transcriptional state has a multinomial distribution over genes.
One thing to notice is that a single cell mainly has one topic, and a single gene mainly contributes to one transcriptional state. These two requirements can easily be checked from the estimated probabilities (e.g. probability > 0.9).
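For instance, for the cell-topic side I have a check like this in mind after fitting (theta_est is just my name for the estimated M x Z topic matrix):

import numpy as np

# fraction of cells whose dominant topic has probability > 0.9
frac_single_topic = np.mean(theta_est.max(axis=1) > 0.9)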
The dataset is a cell-by-gene matrix, with each element being the number of transcripts for that gene in that cell.
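To make the generative process concrete, here is a minimal NumPy sketch of how such a matrix can be simulated under the model above (all dimensions and the Dirichlet concentrations are arbitrary illustration values; small concentrations give the peaked distributions mentioned above):

import numpy as np

rng = np.random.RandomState(0)
M, Z, K, N = 100, 3, 5, 200   # cells, topics, transcriptional states, genes
n_transcripts = 1000          # transcripts sampled per cell

theta = rng.dirichlet(0.1 * np.ones(Z), size=M)  # cell -> topic distributions
phi = rng.dirichlet(0.1 * np.ones(K), size=Z)    # topic -> state distributions
psi = rng.dirichlet(0.1 * np.ones(N), size=K)    # state -> gene distributions

docs = np.zeros((M, N), dtype=int)
for m in range(M):
    # draw a topic, then a state, then genes for each of the cell's transcripts
    zs = rng.choice(Z, size=n_transcripts, p=theta[m])
    for z in np.unique(zs):
        ks = rng.choice(K, size=(zs == z).sum(), p=phi[z])
        for k in np.unique(ks):
            docs[m] += rng.multinomial((ks == k).sum(), psi[k])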
The parameters and the log-likelihood are specified as follows:
def logp_lda_doc(theta, phi, psi):
    """Return the log-likelihood of the given gene expressions.

    Z: number of topics in the model
    M: number of cells
    K: number of transcriptional states / cellular subpopulations
    N: number of genes

    Parameters
    ----------
    theta : tensor (M x Z)
        Topic distribution for each cell.
    phi : tensor (Z x K)
        Transcriptional-state distribution within each topic.
    psi : tensor (K x N)
        Gene distribution within each transcriptional state.
    """
    def ll_docs_f(docs):
        dixs, vixs = docs.nonzero()
        vfreqs = docs[dixs, vixs]
        # Inner logsumexp, looped over topics: within each topic, marginalize
        # over the transcriptional states. The results are stored in results1.
        results1, updates1 = theano.scan(
            lambda phi_z, psi, vixs: pmmath.logsumexp(
                tt.log(tt.tile(phi_z, (vixs.shape[0], 1))) + tt.log(psi.T[vixs]),
                axis=1).ravel(),
            sequences=phi,
            non_sequences=[psi, vixs])
        # Outer logsumexp: marginalize over topics, weighted by transcript counts.
        ll_docs = vfreqs * pmmath.logsumexp(
            tt.log(theta[dixs]) + results1.T, axis=1).ravel()
        return tt.sum(ll_docs)

    return ll_docs_f
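For completeness, the likelihood is plugged into the model roughly like this, mirroring the stick-breaking Dirichlet setup of the PyMC3 LDA example (minibatch_size and doc_t, the minibatch tensor, are placeholders from my setup):

import numpy as np
import pymc3 as pm
from pymc3 import Dirichlet
from pymc3.distributions.transforms import t_stick_breaking

with pm.Model() as model:
    # local RV: per-cell topic distributions, to be approximated by the encoder
    theta = Dirichlet('theta',
                      a=pm.floatX((1.0 / n_topics) * np.ones((minibatch_size, n_topics))),
                      shape=(minibatch_size, n_topics),
                      transform=t_stick_breaking(1e-9))
    # global RVs: topic -> state and state -> gene distributions
    phi = Dirichlet('phi',
                    a=pm.floatX((1.0 / n_subppl) * np.ones((n_topics, n_subppl))),
                    shape=(n_topics, n_subppl),
                    transform=t_stick_breaking(1e-9))
    psi = Dirichlet('psi',
                    a=pm.floatX((1.0 / n_genes) * np.ones((n_subppl, n_genes))),
                    shape=(n_subppl, n_genes),
                    transform=t_stick_breaking(1e-9))
    doc = pm.DensityDist('doc', logp_lda_doc(theta, phi, psi), observed=doc_t)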
The AEVB part is what I am struggling with. I have briefly gone through the papers on AEVB and ADVI and have a basic understanding of them, but I haven't fully digested them yet, so I still have problems with the implementation.
I was following the PyMC3 LDA documentation and tried to adjust its example code. But since only one latent variable can be specified for ADVI in PyMC3, I just added more hidden layers to the encoder, made one of the hidden layers have 2 * (#transcriptional states - 1) neurons, and kept the output layer at 2 * (#topics - 1) neurons:
class LDAEncoder:
    def __init__(self, n_genes, n_hidden1, n_hidden2, n_topics, n_subppl,
                 p_corruption=0, random_seed=1):
        rng = np.random.RandomState(random_seed)
        self.n_genes = n_genes
        self.n_hidden1 = n_hidden1
        self.n_hidden2 = n_hidden2
        self.n_topics = n_topics
        self.n_subppl = n_subppl  # number of transcriptional states
        self.w0 = shared(0.01 * rng.randn(n_genes, n_hidden1).ravel(), name='w0')
        self.b0 = shared(0.01 * rng.randn(n_hidden1), name='b0')
        self.w1 = shared(0.01 * rng.randn(n_hidden1, 2 * (n_subppl - 1)).ravel(), name='w1')
        self.b1 = shared(0.01 * rng.randn(2 * (n_subppl - 1)), name='b1')
        self.w2 = shared(0.01 * rng.randn(2 * (n_subppl - 1), n_hidden2).ravel(), name='w2')
        self.b2 = shared(0.01 * rng.randn(n_hidden2), name='b2')
        self.w3 = shared(0.01 * rng.randn(n_hidden2, 2 * (n_topics - 1)).ravel(), name='w3')
        self.b3 = shared(0.01 * rng.randn(2 * (n_topics - 1)), name='b3')
        self.rng = MRG_RandomStreams(seed=random_seed)
        self.p_corruption = p_corruption

    def encode(self, xs):
        if 0 < self.p_corruption:
            # randomly zero out non-zero entries (denoising corruption)
            dixs, vixs = xs.nonzero()
            mask = tt.set_subtensor(
                tt.zeros_like(xs)[dixs, vixs],
                self.rng.binomial(size=dixs.shape, n=1, p=1 - self.p_corruption))
            xs_ = xs * mask
        else:
            xs_ = xs
        w0 = self.w0.reshape((self.n_genes, self.n_hidden1))
        w1 = self.w1.reshape((self.n_hidden1, 2 * (self.n_subppl - 1)))
        w2 = self.w2.reshape((2 * (self.n_subppl - 1), self.n_hidden2))
        w3 = self.w3.reshape((self.n_hidden2, 2 * (self.n_topics - 1)))
        h1s = tt.tanh(xs_.dot(w0) + self.b0)
        ks = h1s.dot(w1) + self.b1  # layer sized for the transcriptional states
        h2s = tt.tanh(ks.dot(w2) + self.b2)
        zs = h2s.dot(w3) + self.b3
        zs_mean = zs[:, :(self.n_topics - 1)]
        zs_std = zs[:, (self.n_topics - 1):]
        return zs_mean, zs_std  # , ks_mean, ks_std

    def get_params(self):
        return [self.w0, self.b0, self.w1, self.b1,
                self.w2, self.b2, self.w3, self.b3]
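The encoder is then attached to theta via local_RVs and trained jointly with ADVI, roughly as in the AEVB section of the example (n_tokens is the total transcript count; the n_tokens / minibatch_size scaling follows the example):

from collections import OrderedDict

encoder = LDAEncoder(n_genes=n_genes, n_hidden1=100, n_hidden2=100,
                     n_topics=n_topics, n_subppl=n_subppl)
local_RVs = OrderedDict([(theta, (encoder.encode(doc_t),
                                  n_tokens / minibatch_size))])

with model:
    approx = pm.fit(10000, method='advi',
                    local_rv=local_RVs,
                    more_obj_params=encoder.get_params())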
(Also, since I am using generated data, I know the true topic for each cell and the true transcriptional state for each gene.) The result is not too bad, but far from good: compared with the true parameters, the model identifies the cell topics with 100% accuracy, but the gene transcriptional states with only around 60% accuracy.
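(The accuracies are computed along these lines, with true_topics / true_states being my simulated labels and theta_est / psi_est the posterior-mean estimates, after matching up the permuted cluster labels:)

import numpy as np

# cells: dominant estimated topic vs. true topic
topic_acc = np.mean(theta_est.argmax(axis=1) == true_topics)
# genes: most probable estimated state vs. true state (psi_est is K x N)
state_acc = np.mean(psi_est.argmax(axis=0) == true_states)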
My encoder specification is surely problematic, as it is based mostly on my intuition. I would appreciate any suggestions, or recommendations for materials I can refer to.
Thank you very much and kind regards.