Combining Theano variables and NumPy arrays in a custom step function

Hi all, I’m working on a fairly simple custom step method. I want to sample from a conjugate Dirichlet posterior using a custom step, but I can’t quite figure out how best to combine the Theano variables from the rest of the graph with the NumPy random sampling functions that can draw from a Dirichlet. I am aware of the gamma-sum representation of the Dirichlet, but I don’t see any straightforward way to use it in Theano short of implementing my own routines for sampling Dirichlet variates. Any ideas or advice? The code below shows my problem.

```python
import numpy as np
import pymc3 as pm
from pymc3.step_methods.arraystep import BlockedStep


class ConjugateDirichletUpdate(BlockedStep):

    def __init__(self, var, concentration, counts):
        self.var = var
        self.vars = [var]
        self.conc = concentration
        self.counts = counts

    def step(self, point):
        new = point.copy()
        alpha = self.conc + self.counts

        # What I want: new[self.var] = np.random.dirichlet(alpha)

        # This samples, though it's with the wrong distribution:
        new[self.var] = np.random.dirichlet(self.counts)
        return new


J = 4
counts = np.sum(np.random.uniform(size=[10, J]) > 0.5, axis=0)

with pm.Model():
    tau = pm.Exponential('tau', lam=1)
    alpha = pm.Deterministic('alpha', tau * np.ones(J))
    p = pm.Dirichlet('p', a=alpha)
    x = pm.Multinomial('x', p=p, n=counts.sum(), observed=counts)
    step = ConjugateDirichletUpdate(p, alpha, counts)
    trace = pm.sample(step=[step], chains=1, cores=1)
```
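For reference, the gamma-sum representation mentioned in the question is short in plain NumPy: draw independent Gamma(alpha_k, 1) variates and divide by their sum. This is only a sketch (the function name `dirichlet_via_gammas` is hypothetical); it operates on numerical arrays, so on its own it does not resolve the Theano issue.

```python
import numpy as np


def dirichlet_via_gammas(alpha, rng):
    """Hypothetical helper: one Dirichlet(alpha) draw via normalized gammas."""
    alpha = np.asarray(alpha, dtype=float)
    # Independent Gamma(alpha_k, scale=1) variates, one per component
    g = rng.gamma(shape=alpha, scale=1.0)
    # Normalizing by the total gives a Dirichlet(alpha) draw on the simplex
    return g / g.sum()


rng = np.random.default_rng(42)
alpha = np.array([2.0, 3.0, 5.0])
draws = np.array([dirichlet_via_gammas(alpha, rng) for _ in range(20000)])
# The sample mean should be close to alpha / alpha.sum(), i.e. [0.2, 0.3, 0.5]
print(draws.mean(axis=0))
```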

I worked out how to do this. Here’s a gist that shows an implementation for a simple case: https://gist.github.com/ckrapu/8be4da91a70763ee62f889c6cd98f700

The basic idea is that, instead of manipulating the Theano symbolic variables passed into the step method, you work with the `point` object, which contains their current numerical values.
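A minimal NumPy-only sketch of that idea: the update reads numbers out of a point dictionary, forms the posterior concentration, and returns a new dictionary with the redrawn value. It assumes the point is a dict keyed by variable name (as in PyMC3) and that `p` is stored on the simplex under the key `'p'`. By default PyMC3 samples a Dirichlet on a stick-breaking-transformed scale under a different key, so this assumption holds only if the variable is declared with `transform=None`. The function name is hypothetical and the prior concentration is held fixed here for simplicity.

```python
import numpy as np


def conjugate_dirichlet_update(point, var_name, prior_conc, counts, rng):
    """Hypothetical helper: redraw point[var_name] from its Dirichlet posterior."""
    # Shallow copy, mirroring point.copy() in the step method
    new = dict(point)
    # Conjugacy: Dirichlet(prior) x Multinomial(counts) -> Dirichlet(prior + counts)
    alpha = np.asarray(prior_conc, dtype=float) + np.asarray(counts)
    new[var_name] = rng.dirichlet(alpha)
    return new


rng = np.random.default_rng(0)
point = {'p': np.full(4, 0.25)}
counts = np.array([3, 5, 0, 2])
new_point = conjugate_dirichlet_update(point, 'p', np.ones(4), counts, rng)
print(new_point['p'])
```

Inside a `BlockedStep`, `step(self, point)` would then return something like `conjugate_dirichlet_update(point, self.var.name, ...)`, so Theano variables never enter the sampling call.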


As a follow-up question, I’d like to figure out how to use a Deterministic variable within a step method. Currently, deterministic variables are not recorded in the `point` object passed to the step method, so it’s not clear how their values can be passed via `point`.
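One workaround is to recompute the Deterministic from the free variables that are in `point`. The sketch below does this for `alpha = tau * ones(J)` from the question. It assumes the PyMC3 3.x convention that an Exponential variable is sampled on the log scale and stored under `'tau_log__'`; the helper names are hypothetical.

```python
import numpy as np

J = 4


def alpha_from_point(point, J=J):
    """Hypothetical helper: recompute the Deterministic alpha = tau * ones(J)."""
    # Assumed key: PyMC3 stores log-transformed tau as 'tau_log__',
    # so it must be exponentiated to recover tau itself
    tau = np.exp(point['tau_log__'])
    return tau * np.ones(J)


def posterior_alpha(point, counts):
    # Conjugate Dirichlet posterior concentration: prior alpha plus counts
    return alpha_from_point(point) + np.asarray(counts)


point = {'tau_log__': np.log(2.0)}
counts = np.array([3, 5, 0, 2])
print(alpha_from_point(point))
print(posterior_alpha(point, counts))
```

Writing the formula twice (once in the model, once in the step) is fragile. PyMC3's `Model.fastfn` may be a cleaner option, since it compiles model expressions into a function that takes a point dict; check that it behaves this way in your PyMC3 version.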

If you’re interested, here is an example of a conjugate step method for the same underlying Categorical/Dirichlet pairing. In this case, a (transition) matrix with Dirichlet rows is updated according to a matrix of “observed” transition counts.
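The row-wise update described above can be sketched in NumPy. This is not the implementation from that repository. It assumes a single observed state sequence and independent Dirichlet priors on each row (`prior[i]`), and the helper names are hypothetical.

```python
import numpy as np


def transition_counts(states, n_states):
    """Hypothetical helper: count observed (from, to) transitions."""
    counts = np.zeros((n_states, n_states), dtype=int)
    # np.add.at accumulates correctly when a (from, to) pair repeats
    np.add.at(counts, (states[:-1], states[1:]), 1)
    return counts


def sample_transition_matrix(states, prior, rng):
    """Hypothetical helper: draw each row from Dirichlet(prior[i] + counts[i])."""
    n_states = prior.shape[0]
    counts = transition_counts(states, n_states)
    # Rows are conditionally independent, so each one gets its own conjugate update
    return np.vstack([rng.dirichlet(prior[i] + counts[i]) for i in range(n_states)])


rng = np.random.default_rng(1)
states = np.array([0, 0, 0, 1, 2, 2, 1, 0])
prior = np.ones((3, 3))
print(transition_counts(states, 3))
P = sample_transition_matrix(states, prior, rng)
print(P)
```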

This step method was recently updated to handle more than just the most basic transition matrix graphs (i.e. stacked rows of Dirichlet distributions), which neatly illustrates the inevitable challenges of writing such step methods. In other words, conjugate step methods face a generalizability challenge that methods like HMC don’t, because conjugate updates depend strongly on the exact form of the graph.

The example above addresses some of this but still has large deficiencies. The way to address this more broadly involves improvements to the underlying graph library (i.e. Theano-PyMC) and a framework that fits this type of work (see Symbolic PyMC).

Also, here are some more conjugate step methods in the old PyMC2. They follow roughly the same pattern as PyMC3 step methods, just without Theano.


Yes, that is precisely the type of example I’m looking for. Along the same lines as the pymc3-hmm repository, I’ve often come across use cases where a distribution built around a dynamic Gaussian model / Kalman filter would also be very useful. Thanks for providing these links.