Updating the prior with the posterior while using minibatches

Hi,

I am trying to use minibatches with samplers such as NUTS. Since `pm.sample` in PyMC3 does not support minibatches, I followed the "Feeding Posterior back in as Prior (updating a model)" example.

But in my case the priors are Dirichlet and Categorical.

My model is as follows:

```python
import numpy as np
import pymc3 as pm

# K (number of mixture components), B (number of binary features) and
# iter_minibatches() are defined elsewhere in my script.

# One shared model: every call to mix_Bern_fn() adds new RVs to it.
bern_model = pm.Model()


def mix_Bern_fn(data, i, prior=None):
    with bern_model:
        pi = pm.Dirichlet('pi_%d' % i, a=prior['pi'], shape=K)
        print(pi)
        dri = pm.Dirichlet('dri_%d' % i, a=prior['dri'], shape=(K, B))
        # One latent component label per row of the current batch
        category_U = pm.Categorical('category_%d' % i, p=pi, shape=data.shape[0])
        print(category_U.tag.test_value)
        # Each row is Bernoulli with the probabilities of its component
        vector_U = pm.Bernoulli('vec_u_%d' % i, p=dri[category_U], observed=data)

    return bern_model


def train_model(bern_minibatch, i):
    with bern_minibatch:
        tr = pm.sample(100, tune=80, chains=1)

    # Summarise this batch's posterior by its mean
    return {
        'pi': tr['pi_%d' % i].mean(axis=0),
        'dri': tr['dri_%d' % i].mean(axis=0),
    }


if __name__ == '__main__':
    with bern_model:
        pi_init = pm.Dirichlet('pi__x_', a=np.ones(K), shape=K)
        dri_init = pm.Dirichlet('dri__x_', a=np.ones((K, B)), shape=(K, B))
    prior = {'pi': pi_init, 'dri': dri_init}  # initial prior
    batcherator = iter_minibatches(chunksize=100)
    for i, X_chunk in enumerate(batcherator):
        #minibatch_x = pm.Minibatch(observed, batch_size=100)
        print(X_chunk.shape)
        bern_minibatch = mix_Bern_fn(X_chunk, i, prior)
        posterior = train_model(bern_minibatch, i)
        prior = posterior
        print(i, ':', prior)
```
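For reference, `iter_minibatches` is not shown in the post. A minimal sketch of what such a helper might look like, assuming it simply yields consecutive row blocks of a NumPy array; this version is hypothetical and takes the data explicitly, unlike the `iter_minibatches(chunksize=100)` call above:

```python
import numpy as np


def iter_minibatches(data, chunksize=100):
    """Yield consecutive row blocks of `data` with at most `chunksize` rows.

    Hypothetical stand-in: the real helper presumably reads chunks from the
    full dataset, whose source is not shown in the post.
    """
    for start in range(0, data.shape[0], chunksize):
        yield data[start:start + chunksize]


# Toy binary data: 250 rows, 8 binary features (made-up sizes).
rng = np.random.default_rng(0)
X = rng.integers(0, 2, size=(250, 8))
shapes = [chunk.shape for chunk in iter_minibatches(X, chunksize=100)]
print(shapes)  # [(100, 8), (100, 8), (50, 8)]
```

The last chunk can be shorter; `mix_Bern_fn` handles that because it sizes `category_%d` from `data.shape[0]`.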

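One detail in `train_model`: it returns only the posterior means, and each mean sums to 1 along its simplex axis. Passed back as `a`, that gives a Dirichlet with total concentration 1, which is flatter than the initial `np.ones` prior, so the information about how concentrated the batch posterior was is discarded. A minimal sketch of one alternative, assuming it is acceptable to approximate each batch posterior by a Dirichlet: match the concentrations to the trace by the method of moments. This is a swapped-in technique, not the approach of the linked example, and `dirichlet_from_samples` is a hypothetical helper. With only 100 draws per batch the variance estimate, and hence the concentration, will be noisy.

```python
import numpy as np


def dirichlet_from_samples(samples):
    """Moment-match Dirichlet concentrations to posterior draws.

    samples: array of shape (n_draws, ..., n_categories) whose last axis lies
    on the simplex, e.g. tr['pi_0'] with shape (n_draws, K) or tr['dri_0']
    with shape (n_draws, K, B). Returns concentrations shaped like one draw.
    """
    mean = samples.mean(axis=0)
    var = samples.var(axis=0)
    # For Dirichlet(alpha) with a0 = alpha.sum():
    #   E[x_k] = alpha_k / a0,  Var[x_k] = E[x_k] * (1 - E[x_k]) / (a0 + 1)
    # so each component gives an estimate of a0; average them per simplex.
    a0 = (mean * (1.0 - mean) / var - 1.0).mean(axis=-1, keepdims=True)
    return a0 * mean


rng = np.random.default_rng(1)
true_alpha = np.array([2.0, 5.0, 3.0])
draws = rng.dirichlet(true_alpha, size=20000)
print(draws.mean(axis=0).sum())  # the mean alone sums to 1 (up to rounding)
fitted = dirichlet_from_samples(draws)
print(fitted)  # close to [2, 5, 3]
```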
Here I used the names `pi_%d` and `dri_%d` because I run a loop over the batches and each batch needs unique variable names. The problem is that all of these variables live in the same model context, so every time I sample a new batch the number of RVs keeps increasing. My dataset is huge, so there will be a lot of batches and therefore a lot of RVs, which I guess would make sampling much harder.
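To make the growth concrete: in the code above every batch adds `pi_i` (K values), `dri_i` (K*B values) and `category_i` (one label per row), and `pm.sample` draws all free variables in the model, not just the newest ones, while every earlier `vec_u_i` keeps its chunk in the likelihood. The sketch below just counts these, assuming every chunk is full; `shared_model_size` is a hypothetical helper, and K=3, B=10 are made-up values for illustration.

```python
def shared_model_size(n_batches, K, B, chunksize):
    """Count scalar free RVs and observed rows in the shared bern_model.

    pi__x_ (K) and dri__x_ (K*B) are created once; each batch then adds
    pi_i (K), dri_i (K*B), category_i (chunksize) and an observed vec_u_i
    holding chunksize rows. Assumes all chunks are full.
    """
    per_batch_free = K + K * B + chunksize
    free_scalars = (K + K * B) + n_batches * per_batch_free
    observed_rows = n_batches * chunksize
    return free_scalars, observed_rows


print(shared_model_size(1000, K=3, B=10, chunksize=100))  # (133033, 100000)
```

After the last batch the model holds one categorical per data row plus duplicated `pi`/`dri` copies, i.e. it ends up larger than a single model fitted to the whole dataset.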

Can anyone please suggest how I can tackle this problem?

I also tried this model with ADVI using minibatches on my huge dataset, but it never really sampled, so I thought of this approach instead.

Any help is much appreciated.

Thanks in advance.