Running with minibatches (memory constraints)

Just as a bit more info, it’s supposed to be an LLDA (labeled LDA) model that looks something like this:

import pymc3 as pm
import numpy as np
import scipy.sparse as sps
import theano.tensor as tt
from pymc3.distributions.transforms import t_stick_breaking
from theano import shared
import theano
theano.config.compute_test_value = 'off'

class LLDA_model_pymc3:
    """Takes in a sparse matrix of feature vectors and a dataframe of labels."""

    def __init__(self, word_counts, feature_names, labels):
        self.wordCounts = word_counts
        self.feature_names = feature_names
        self.labels = labels
        self.nTopics = labels.shape[1]        # K
        self.vocabLen = word_counts.shape[1]  # V
        self.nDocs = word_counts.shape[0]     # D
        self.nTokens = np.sum(word_counts[word_counts.nonzero()])

    def build_pymc3_model(self, minibatchSize=200):
        self.minibatchSize = minibatchSize

        def logp_lda_doc(beta, theta):
            """Returns the log-likelihood function for given documents.
            K : number of topics in the model
            V : number of words (size of vocabulary)
            D : number of documents (in a mini-batch)

            Parameters
            ----------
            beta : tensor (K x V)
                Word distributions.
            theta : tensor (D x K)
                Topic distributions for documents (set as a strong Dirichlet
                for the supervised model)
            """
            def docLikelihoodFunction(docs):
                documentIndex, vocabIndex = docs.nonzero()
                vocabFreqs = docs[documentIndex, vocabIndex]
                docLikelihood = vocabFreqs * pm.math.logsumexp(
                    tt.log(theta[documentIndex]) + tt.log(beta.T[vocabIndex]), axis=1).ravel()

                # per-word log-likelihood * number of tokens in the whole dataset
                return tt.sum(docLikelihood) / tt.sum(vocabFreqs) * self.nTokens

            return docLikelihoodFunction
    
        self.doc_t_minibatch = pm.Minibatch(self.wordCounts.toarray(), minibatchSize)
        self.doc_t = shared(self.wordCounts.toarray()[:minibatchSize], borrow=True)
        self.topic_t = shared(np.asarray(self.labels)[:minibatchSize], borrow=True)
        self.topic_t_minibatch = pm.Minibatch(np.asarray(self.labels), minibatchSize)
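        # The model below is built on the fixed-size shared slices; inference()
        # later swaps in the pm.Minibatch streams via `more_replacements`.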

        with pm.Model() as model:
            beta = pm.Dirichlet('beta',
                                a=pm.floatX((1.0 / self.nTopics) * np.ones((self.nTopics, self.vocabLen))),
                                shape=(self.nTopics, self.vocabLen),
                                transform=t_stick_breaking(1e-9))
            doc = pm.DensityDist('doc', logp_lda_doc(beta, self.topic_t), observed=self.doc_t)

        self.model = model

    def inference(self, n_steps=10000, start_learn_rate=0.1):
        try:
            self.model
        except AttributeError:
            print("No pymc3 model has been defined")
        else:
            n = start_learn_rate
            s = shared(n)

            def reduce_rate(a, h, i):
                # callback args: (approximation, loss history, iteration number)
                s.set_value(n / ((i / self.minibatchSize) + 1) ** 0.7)

            with self.model:
                approx = pm.MeanField()
                approx.scale_cost_to_minibatch = False
                inference = pm.KLqp(approx)

            inference.fit(n_steps, callbacks=[reduce_rate],
                          obj_optimizer=pm.sgd(learning_rate=s),
                          total_grad_norm_constraint=200,
                          more_replacements={self.doc_t: self.doc_t_minibatch,
                                             self.topic_t: self.topic_t_minibatch})

            self.approx = approx

            samples = pm.sample_approx(approx, draws=100)
            self.vocab_samples = samples['beta'].mean(axis=0)
    
    def print_top_words(self, n_top_words=10):
        try:
            self.vocab_samples
        except AttributeError:
            print("Error: build the model and perform inference first")
        else:
            for i in range(len(self.vocab_samples)):
                print(("Topic #%d: " % i) + " ".join(
                    [self.feature_names[j]
                     for j in self.vocab_samples[i].argsort()[:-n_top_words - 1:-1]]))
        
    def predictions(self, test_word_counts, apply_softmax=True):
        # keyword renamed so it doesn't shadow the softmax helper below
        def softmax(x):
            e_x = np.exp(x - np.max(x, axis=1)[:, None])
            return e_x / e_x.sum(axis=1)[:, None]

        try:
            self.vocab_samples
        except AttributeError:
            print("Error: build the model and perform inference first")
        else:
            predictions = test_word_counts.dot(self.vocab_samples.transpose())
            if apply_softmax:
                predictions = softmax(predictions)
            return predictions
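
For reference, a minimal usage sketch (the names `X_train`, `X_test`, `vocab`, and `Y` are hypothetical placeholders, assuming scipy sparse document-term matrices, the matching feature names from the vectorizer, and a D x K one-hot label array):

# hypothetical inputs: sparse doc-term matrices, feature names, one-hot labels
llda = LLDA_model_pymc3(X_train, vocab, Y)
llda.build_pymc3_model(minibatchSize=200)
llda.inference(n_steps=10000, start_learn_rate=0.1)
llda.print_top_words(n_top_words=10)
probs = llda.predictions(X_test)  # (n_test_docs x K) label scores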