AEVB implementation question: LDA model with two latent variables



I am currently trying to implement an LDA model with two latent variables in PyMC3 using AEVB.
The model is used to cluster gene distributions for different types of cells and has two latent variables: transcriptional states and topics. Genes and cells play the roles of words and documents, and transcripts play the role of tokens in the original LDA model (Blei et al.).

The model is specified as follows:
Each cell has a multinomial distribution over topics (latent variable 1), each topic has a multinomial distribution over all transcriptional states (latent variable 2), and each transcriptional state has a multinomial distribution over genes.
One thing to note is that each cell is mainly associated with one topic, and each gene mainly contributes to one transcriptional state. Both requirements can easily be checked from the estimated probabilities (e.g. probability > 0.9).
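A small NumPy sketch of that check, assuming the estimated cell-topic probabilities are available as an M x Z array whose rows sum to 1. The helper name `dominance_fraction` and the 0.9 default are illustrative only. For genes, one possible (assumed) analogue is to normalize each column of the estimated psi across states and apply the same function to its transpose.

```python
import numpy as np


def dominance_fraction(probs, threshold=0.9):
    """Fraction of rows whose largest probability exceeds `threshold`.

    `probs` is an (n_items x n_categories) array with rows summing to 1,
    e.g. estimated cell-topic probabilities (M x Z). Hypothetical helper.
    """
    probs = np.asarray(probs, dtype=float)
    return float(np.mean(probs.max(axis=1) > threshold))


theta_hat = np.array([[0.95, 0.03, 0.02],
                      [0.10, 0.85, 0.05],
                      [0.01, 0.01, 0.98]])
print(dominance_fraction(theta_hat))  # 2 of the 3 rows exceed 0.9
```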

The dataset is a cell-by-gene matrix in which each element is the number of transcripts of that gene in that cell.
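To make the generative story concrete, here is a hypothetical NumPy simulator that draws theta, phi and psi from symmetric Dirichlet priors and then generates each transcript by sampling a topic, then a transcriptional state, then a gene. It is not the simulator from the team's R package; the priors, the fixed number of transcripts per cell and the name `simulate_counts` are assumptions for illustration.

```python
import numpy as np


def simulate_counts(n_cells, n_genes, n_topics, n_states,
                    transcripts_per_cell=500, alpha=0.5, seed=0):
    """Simulate a cell-by-gene count matrix from the three-level model.

    Hypothetical settings: symmetric Dirichlet(alpha) priors on theta, phi
    and psi, and the same number of transcripts in every cell.
    """
    rng = np.random.default_rng(seed)
    theta = rng.dirichlet(np.full(n_topics, alpha), size=n_cells)  # M x Z
    phi = rng.dirichlet(np.full(n_states, alpha), size=n_topics)   # Z x K
    psi = rng.dirichlet(np.full(n_genes, alpha), size=n_states)    # K x N
    counts = np.zeros((n_cells, n_genes), dtype=np.int64)
    for m in range(n_cells):
        # Each transcript: topic ~ theta[m], state ~ phi[topic], gene ~ psi[state]
        topics = rng.choice(n_topics, size=transcripts_per_cell, p=theta[m])
        for z in topics:
            k = rng.choice(n_states, p=phi[z])
            n = rng.choice(n_genes, p=psi[k])
            counts[m, n] += 1
    return counts, theta, phi, psi


counts, theta, phi, psi = simulate_counts(10, 30, 3, 4, seed=0)
print(counts.shape)  # (10, 30)
```

Because transcripts are drawn independently, each row of `counts` is marginally Multinomial(transcripts_per_cell, theta[m] @ phi @ psi), which is the per-cell likelihood that `logp_lda_doc` evaluates (up to the multinomial constant).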

The parameters are specified and the log-likelihood is calculated as follows:

```python
import theano
import theano.tensor as tt
from pymc3 import math as pmmath


def logp_lda_doc(theta, phi, psi):
    """Return the log-likelihood of the given gene expressions.

    Z: number of topics in the model
    M: number of cells
    K: number of transcriptional states / cellular subpopulations
    N: number of genes

    theta: tensor (M x Z)
        Topic distribution for each cell.
    phi: tensor (Z x K)
        Transcriptional-state distribution within each topic.
    psi: tensor (K x N)
        Gene distribution within each transcriptional state.
    """
    def ll_docs_f(docs):
        # Row (cell) and column (gene) indices of the nonzero counts, and the counts
        dixs, vixs = docs.nonzero()
        vfreqs = docs[dixs, vixs]
        # Loop over topics (rows of phi). For each topic z, the inner logsumexp
        # over states computes log sum_k phi[z, k] * psi[k, gene];
        # results1 has shape (Z, n_nonzero).
        results1, updates1 = theano.scan(
            lambda phi_z, psi, vixs: pmmath.logsumexp(
                tt.log(tt.tile(phi_z, (vixs.shape[0], 1))) + tt.log(psi.T[vixs]),
                axis=1).ravel(),
            sequences=[phi],
            non_sequences=[psi, vixs])
        # Outer logsumexp over topics, weighted by each cell's topic distribution
        ll_docs = vfreqs * pmmath.logsumexp(
            tt.log(theta[dixs]) + results1.T, axis=1).ravel()
        return tt.sum(ll_docs)
    return ll_docs_f
```
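As a sanity check on the nested logsumexp, the sketch below (plain NumPy/SciPy rather than Theano, with random toy parameters) compares it with the closed form log((theta @ phi @ psi)[m, n]), since sum_z theta[m, z] sum_k phi[z, k] psi[k, n] is just a matrix product. If the two agree, the Theano version could in principle be written as `tt.sum(vfreqs * tt.log(tt.dot(tt.dot(theta, phi), psi)[dixs, vixs]))` without `theano.scan`; that rewrite is a suggestion, not something tested in this thread.

```python
import numpy as np
from scipy.special import logsumexp

rng = np.random.default_rng(0)
M, Z, K, N = 4, 3, 5, 7
theta = rng.dirichlet(np.ones(Z), size=M)  # M x Z
phi = rng.dirichlet(np.ones(K), size=Z)    # Z x K
psi = rng.dirichlet(np.ones(N), size=K)    # K x N
docs = rng.poisson(2.0, size=(M, N))       # toy cell-by-gene counts

dixs, vixs = docs.nonzero()
vfreqs = docs[dixs, vixs]

# Nested form mirroring ll_docs_f: inner logsumexp over states for each topic,
# then outer logsumexp over topics.
inner = np.stack([logsumexp(np.log(phi[z])[None, :] + np.log(psi.T[vixs]), axis=1)
                  for z in range(Z)])  # Z x n_nonzero
ll_nested = np.sum(vfreqs * logsumexp(np.log(theta[dixs]) + inner.T, axis=1))

# Closed form: p(gene n | cell m) = (theta @ phi @ psi)[m, n]
ll_matrix = np.sum(vfreqs * np.log((theta @ phi @ psi)[dixs, vixs]))

print(np.isclose(ll_nested, ll_matrix))  # True
```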

The AEVB part is what I am struggling with. I have briefly gone through the papers on AEVB and ADVI and have a basic understanding of them, but I haven't fully digested them, so I still have trouble with the implementation.
I followed the PyMC3 LDA documentation and tried to adapt the example code. But since only one latent variable can be specified for ADVI in PyMC3, I only added more hidden layers to the encoder: one hidden layer has 2*(#transcriptional states - 1) neurons, and the output layer keeps 2*(#topics - 1) neurons:

```python
import numpy as np
import theano.tensor as tt
from theano import shared
from theano.sandbox.rng_mrg import MRG_RandomStreams


class LDAEncoder:
    def __init__(self, n_genes, n_hidden1, n_hidden2, n_topics, n_subppl, p_corruption=0, random_seed=1):
        rng = np.random.RandomState(random_seed)
        self.n_genes = n_genes
        self.n_hidden1 = n_hidden1
        self.n_hidden2 = n_hidden2
        self.n_topics = n_topics
        self.n_subppl = n_subppl   # transcriptional states
        # Weights are stored flattened and reshaped in encode()
        self.w0 = shared(0.01 * rng.randn(n_genes, n_hidden1).ravel(), name='w0')
        self.b0 = shared(0.01 * rng.randn(n_hidden1), name='b0')
        self.w1 = shared(0.01 * rng.randn(n_hidden1, 2 * (n_subppl - 1)).ravel(), name='w1')
        self.b1 = shared(0.01 * rng.randn(2 * (n_subppl - 1)), name='b1')
        self.w2 = shared(0.01 * rng.randn(2 * (n_subppl - 1), n_hidden2).ravel(), name='w2')
        self.b2 = shared(0.01 * rng.randn(n_hidden2), name='b2')
        self.w3 = shared(0.01 * rng.randn(n_hidden2, 2 * (n_topics - 1)).ravel(), name='w3')
        self.b3 = shared(0.01 * rng.randn(2 * (n_topics - 1)), name='b3')
        self.rng = MRG_RandomStreams(seed=random_seed)
        self.p_corruption = p_corruption

    def encode(self, xs):
        if 0 < self.p_corruption:
            # Denoising: randomly zero out some of the nonzero input entries
            dixs, vixs = xs.nonzero()
            mask = tt.set_subtensor(
                tt.zeros_like(xs)[dixs, vixs],
                self.rng.binomial(size=dixs.shape, n=1, p=1 - self.p_corruption))
            xs_ = xs * mask
        else:
            xs_ = xs
        w0 = self.w0.reshape((self.n_genes, self.n_hidden1))
        w1 = self.w1.reshape((self.n_hidden1, 2 * (self.n_subppl - 1)))
        w2 = self.w2.reshape((2 * (self.n_subppl - 1), self.n_hidden2))
        w3 = self.w3.reshape((self.n_hidden2, 2 * (self.n_topics - 1)))
        h1s = tt.tanh(xs_.dot(w0) + self.b0)  # genes -> hidden layer 1
        ks = h1s.dot(w1) + self.b1            # linear layer of width 2*(K-1)
        h2s = tt.tanh(ks.dot(w2) + self.b2)   # hidden layer 2
        zs = h2s.dot(w3) + self.b3            # output layer of width 2*(Z-1)
        # First half: means; second half: scale parameters for the cell topics
        zs_mean = zs[:, :(self.n_topics - 1)]
        zs_std = zs[:, (self.n_topics - 1):]
        return zs_mean, zs_std  # ks_mean, ks_std are not returned

    def get_params(self):
        return [self.w0, self.b0, self.w1, self.b1, self.w2, self.b2, self.w3, self.b3]
```
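For reference, here is a NumPy sketch of how a 2*(Z-1)-wide encoder output can become a per-cell topic distribution: the first half is read as means, the second half as unconstrained scales mapped to positive values with softplus (an assumption; check how your PyMC3 version interprets the second output), a reparameterized Gaussian sample is drawn, and the Z-1 coordinates are mapped onto the Z-simplex with an inverse additive log-ratio (softmax with a fixed zero coordinate). PyMC3's own Dirichlet transform may differ in detail, and `sample_theta` is a hypothetical name. Note also that in the encoder as written only `zs` parameterizes an approximate posterior; the `ks` layer of width 2*(K-1) is an ordinary hidden layer and does not produce per-gene state probabilities.

```python
import numpy as np


def sample_theta(zs, n_topics, rng):
    """Draw one reparameterized sample of cell-topic proportions (M x Z)
    from an encoder output of shape (M, 2 * (n_topics - 1))."""
    d = n_topics - 1
    mu = zs[:, :d]                          # first half: means
    sigma = np.logaddexp(0.0, zs[:, d:])    # softplus -> positive scales (assumed)
    eps = rng.standard_normal(mu.shape)
    u = mu + sigma * eps                    # reparameterization trick
    # Inverse additive log-ratio: append a fixed 0 coordinate, then softmax
    u = np.concatenate([u, np.zeros((u.shape[0], 1))], axis=1)
    u = u - u.max(axis=1, keepdims=True)    # numerical stability
    e = np.exp(u)
    return e / e.sum(axis=1, keepdims=True)


rng = np.random.default_rng(0)
zs = rng.standard_normal((4, 2 * (5 - 1)))  # toy output: 4 cells, Z = 5 topics
theta = sample_theta(zs, 5, rng)
print(theta.shape, np.allclose(theta.sum(axis=1), 1.0))  # (4, 5) True
```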

(Also, since I am using generated data, I know the true topic for each cell and the true transcriptional state for each gene.) The result is not too bad, but far from good. Compared with the true parameters, the model identifies cell topics with 100% accuracy, but gene transcriptional states with only about 60% accuracy.
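Since cluster indices are only identified up to a permutation, accuracy against the true labels needs a relabeling step. One way to do this (not necessarily what the notebook does) is to build the contingency table of true versus estimated labels and pick the best one-to-one matching with SciPy's `linear_sum_assignment`; `matched_accuracy` is a hypothetical helper name.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment


def matched_accuracy(true_labels, est_labels):
    """Accuracy after relabeling estimated clusters to best match the truth."""
    true_labels = np.asarray(true_labels)
    est_labels = np.asarray(est_labels)
    t_vals, t_idx = np.unique(true_labels, return_inverse=True)
    e_vals, e_idx = np.unique(est_labels, return_inverse=True)
    # contingency[i, j] = number of items with true label i and estimated label j
    contingency = np.zeros((len(t_vals), len(e_vals)), dtype=int)
    np.add.at(contingency, (t_idx, e_idx), 1)
    rows, cols = linear_sum_assignment(-contingency)  # maximize total matches
    return contingency[rows, cols].sum() / len(true_labels)


truth = [1, 1, 2, 2, 3, 3]
est = [2, 2, 0, 0, 1, 0]  # same clusters under different names, one mistake
print(matched_accuracy(truth, est))  # 5 of 6 items match after relabeling
```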

My encoder specification is surely problematic, as it is based mostly on intuition. I would appreciate any suggestions, or recommendations for materials I can refer to.

Thank you very much and kind regards.


This is an interesting problem - do you have a notebook and data you can share?


This is a project I am working on with a professor, so I may not be able to share notes on the model, as it is still in progress. But I can share the data and the code to run the model, including the results comparison (the celda_CG_validation.ipynb file on my GitHub).

simCG_counts_Z5K6_copy.npy (3.4 MB)

simCG_z_Z5K6_copy.npy (3.5 KB)

simCG_y_Z5K6_copy.npy (7.9 KB)

The simCG_counts_Z5K6_copy.npy file is the cell-by-gene count matrix (444 x 999).
The simCG_z_Z5K6_copy.npy file is a vector of length 444 containing the true topic for each cell.
The simCG_y_Z5K6_copy.npy file is a vector of length 999 containing the true transcriptional state for each gene.
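A hypothetical loader for the three files, assuming they sit in one directory under exactly these names. It checks that the count matrix has one row per entry of the topic vector and one column per entry of the state vector, and returns the nonzero (cell, gene, count) triples in the form `ll_docs_f` builds internally. Labels produced in R may be 1-based, so they may need shifting before comparison with 0-based indices.

```python
import os

import numpy as np


def load_sim_data(directory=".", tag="Z5K6_copy"):
    """Load the simulated counts and true labels (hypothetical helper)."""
    def path(kind):
        return os.path.join(directory, f"simCG_{kind}_{tag}.npy")

    counts = np.load(path("counts"))  # cells x genes, e.g. 444 x 999
    z = np.load(path("z"))            # true topic for each cell
    y = np.load(path("y"))            # true transcriptional state for each gene
    if counts.ndim != 2 or counts.shape != (z.shape[0], y.shape[0]):
        raise ValueError(f"inconsistent shapes: {counts.shape}, {z.shape}, {y.shape}")
    dixs, vixs = counts.nonzero()
    return counts, z, y, (dixs, vixs, counts[dixs, vixs])
```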

Thank you for being willing to help!

By the way, in case anyone is interested in how these data were generated, the simulation is available in an R package built by our team.


I played around with the code a little. While I don't know why the model is underperforming, the setup of the approximation neural net and the logp look fine to me. Please see the secret gist with some additional model checks.

What I tried:
1. Standardizing the input matrix, which doesn't seem to improve much here.
2. Using minibatches.
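On the minibatch point: when only a subset of cells is used per step, the batch log-likelihood has to be scaled by M / batch_size so that it estimates the full-data log-likelihood; this is the scaling PyMC3 applies when `total_size` is passed to an observed variable. The NumPy sketch below uses toy parameters and hypothetical function names, and shows that averaging the scaled estimates over an equal-size partition of the cells recovers the full log-likelihood exactly.

```python
import numpy as np


def cell_loglik(counts, theta, phi, psi):
    """Per-cell multinomial log-likelihood, without the combinatorial constant."""
    p = theta @ phi @ psi  # cells x genes, rows sum to 1
    return (counts * np.log(p)).sum(axis=1)


def minibatch_loglik(counts, theta, phi, psi, idx):
    """Unbiased estimate of the full log-likelihood from the cells in idx."""
    M = counts.shape[0]
    ll = cell_loglik(counts[idx], theta[idx], phi, psi).sum()
    return ll * M / len(idx)  # rescale to the full data set


rng = np.random.default_rng(1)
M, Z, K, N = 12, 3, 4, 10
theta = rng.dirichlet(np.ones(Z), size=M)
phi = rng.dirichlet(np.ones(K), size=Z)
psi = rng.dirichlet(np.ones(N), size=K)
counts = rng.poisson(3.0, size=(M, N))

full = cell_loglik(counts, theta, phi, psi).sum()
batches = np.split(rng.permutation(M), 4)  # 4 minibatches of 3 cells
estimates = [minibatch_loglik(counts, theta, phi, psi, b) for b in batches]
print(np.isclose(np.mean(estimates), full))  # True
```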

Hope you find it somehow helpful :slight_smile:


Thank you very much for trying it out!

The checking plots are great help.