Is writing a Gibbs sampler in PyMC just as hard as coding it from scratch?

Hello, I’m trying to implement a custom Gibbs sampler in PyMC3. I can’t figure out a simple, idiomatic way to specify my sampler, and I’m wondering if I’m missing the right way to do it. It seems like Gibbs sampling isn’t what PyMC is designed for, so maybe that’s the issue.

Below is some code I wrote without PyMC that implements a Gibbs sampler for the posterior of the population genetics parameters f and r, given observations of organisms with different genotypes (AA, Aa, or aa). Basically, you sample a latent variable Z (whether or not an observation is inbred) conditioned on f and r, and then you sample f and r conditioned on Z by taking advantage of Beta-Binomial conjugacy.
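Concretely, the full conditionals being sampled are (with Beta prior parameters $a_1, \dots, a_4$, $n$ observations, and $Z_i = 1$ meaning observation $i$ is inbred):

$$P(Z_i = 1 \mid g_i, f, r) = \begin{cases} \dfrac{fr}{fr + (1-f)r^2} & g_i = \text{AA} \\[4pt] 0 & g_i = \text{Aa} \\[4pt] \dfrac{f(1-r)}{f(1-r) + (1-f)(1-r)^2} & g_i = \text{aa} \end{cases}$$

$$f \mid Z \sim \text{Beta}\Big(a_1 + \sum_i Z_i,\; a_2 + n - \sum_i Z_i\Big), \qquad r \mid Z, g \sim \text{Beta}(a_3 + n_A,\; a_4 + n_a)$$

where $n_A$ and $n_a$ count the A and a alleles: two per non-inbred individual, one per inbred individual, since an inbred genotype amounts to a single allele draw. Here’s the code: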

from collections import defaultdict

import numpy as np
import matplotlib.pyplot as plt

np.random.seed(123)
np.random.seed(123)

def gibbs_hw(niters, data,
             prior_params=(1, 1, 1, 1),
             initial_values={'f': 0.5, 'r': 0.5}):
    
    # Turn counts into list of strings
    obs = np.array(['AA']*data[0] + ['Aa']*data[1] + ['aa']*data[2])
    
    f = np.zeros((niters))
    r = np.zeros_like(f)
    f[0] = initial_values['f']
    r[0] = initial_values['r']
    
    for i in range(1, niters):
        # Z_i is whether the ith observation is inbred or not.
        # g_i is the genotype of the ith individual
        # Calculate p(Z|f,r) for each case of g_i:
        zi_prob_map = {
            'AA': f[i-1] * r[i-1] / (f[i-1] * r[i-1] + (1 - f[i-1]) * r[i-1] ** 2),
            'Aa': 0,
            'aa': f[i-1] * (1 - r[i-1]) / (f[i-1] * (1 - r[i-1]) + (1 - f[i-1]) * (1 - r[i-1]) ** 2)
        }
        
        z_probs = np.array([zi_prob_map[key] for key in obs])        
        z = np.random.uniform(size = z_probs.size) < z_probs
        n_ibd = z.sum()
        n_not_ibd = (~z).sum()
        
        f[i] = np.random.beta(n_ibd + prior_params[0], n_not_ibd + prior_params[1])
        
        # Get counts of genotypes among NOT inbred individuals.
        types, not_ibd_type_counts = np.unique(obs[~z], return_counts=True)
        not_ibd_counts = defaultdict(int, zip(types, not_ibd_type_counts))
        nz_A = 2 * not_ibd_counts["AA"] + not_ibd_counts["Aa"]
        nz_a = 2 * not_ibd_counts["aa"] + not_ibd_counts["Aa"]

        # Get counts of genotypes among inbred individuals.
        types, ibd_type_counts = np.unique(obs[z], return_counts=True)
        ibd_counts = defaultdict(int, zip(types, ibd_type_counts))
        z_A = ibd_counts["AA"]
        z_a = ibd_counts["aa"]

        r[i] = np.random.beta(prior_params[2] + nz_A + z_A, prior_params[3] + nz_a + z_a)
    
    return {
        'f': f,
        'r': r
    }
        
        
out = gibbs_hw(niters=10000, data=(50, 21, 29))
plt.hist2d(out['f'], out['r'], bins=75)
plt.show()

I would like to implement this in PyMC3, but I can’t figure out how. This posterior can be sampled a different way using Metropolis or NUTS, like so:

import pymc3 as pm

counts = np.array([50, 21, 29])
# Integer encoding of the genotypes:
data_enum = {
    0: 'AA',
    1: 'Aa',
    2: 'aa'
}

data = np.array([0]*counts[0] + [1]*counts[1] + [2]*counts[2])

with pm.Model() as hardy_weinberg:
    f = pm.Beta('f', alpha=1, beta=1) # Uniform
    r = pm.Beta('r', alpha=1, beta=1)
    # Marginal genotype probabilities under the inbreeding model:
    param1 = f*r + (1-f)*r**2          # P(AA)
    param2 = 2*(1-f)*r*(1-r)           # P(Aa)
    param3 = f*(1-r) + (1-f)*(1-r)**2  # P(aa)
    genotype = pm.Categorical('genotype', p=pm.math.stack([param1, param2, param3]),
                              observed=data)
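which I then sample the usual way, e.g.:

with hardy_weinberg:
    trace = pm.sample(1000, return_inferencedata=True)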

So that’s pretty cool. But it’s not obvious to me how I’m supposed to implement a Gibbs sampler where the graph is cyclic, as in Z depends on f and r, and f and r depend on Z. I’m getting the impression that this isn’t one of the main use cases PyMC is designed for.

I found this page on making custom step classes: Using a custom step method for sampling from locally conjugate posterior distributions — PyMC documentation

But looking at the code, it looks like this just amounts to writing it from scratch again using the NumPy random samplers. Am I missing something with the available step classes?

PS: I love this community and this package! PyMC is sick. I hope I can contribute a good example or tutorial for other beginners.

Your inner loop would probably be what goes into the PyMC stepper, yes.

What exactly are you concerned about?

My suggestion would be to write it first and then see whether anything can be rewritten more idiomatically, rather than worrying about that up front.

The generative (PyMC) graph can’t be cyclic, but neither can a NumPy expression, so I don’t think you mean exactly that. Your step sampler can have two stages of dependency just fine, however. The only thing PyMC cares about is that the step returns the next accepted values after each call; what is done to achieve that is up to the step method.
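For instance, a two-stage step for your original model could look something like this untested sketch (HWGibbsStep is a made-up name; it assumes f and r were created with transform=None so that point stores them under their plain names):

import numpy as np
from pymc3.step_methods.arraystep import BlockedStep

class HWGibbsStep(BlockedStep):
    # One full Gibbs sweep per call: Z | f, r, then f, r | Z.
    def __init__(self, f_var, r_var, obs, prior_params=(1, 1, 1, 1)):
        self.vars = [f_var, r_var]  # the variables this step is responsible for
        self.obs = np.asarray(obs)  # genotype strings: 'AA', 'Aa', 'aa'
        self.a = prior_params

    def step(self, point: dict):
        f, r = float(point['f']), float(point['r'])
        # Stage 1: draw the latent Z given the current f, r.
        # Z stays local to the step; it never enters the PyMC graph.
        p_z = {'AA': f * r / (f * r + (1 - f) * r ** 2),
               'Aa': 0.0,
               'aa': f * (1 - r) / (f * (1 - r) + (1 - f) * (1 - r) ** 2)}
        z = np.random.uniform(size=self.obs.size) < np.array([p_z[g] for g in self.obs])
        # Stage 2: draw f and r given Z via the conjugate Beta updates.
        point['f'] = np.random.beta(self.a[0] + z.sum(), self.a[1] + (~z).sum())
        not_z = self.obs[~z]
        n_A = 2 * (not_z == 'AA').sum() + (not_z == 'Aa').sum() + (self.obs[z] == 'AA').sum()
        n_a = 2 * (not_z == 'aa').sum() + (not_z == 'Aa').sum() + (self.obs[z] == 'aa').sum()
        point['r'] = np.random.beta(self.a[2] + n_A, self.a[3] + n_a)
        return point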


Thanks for the response. I’m just having trouble figuring out what I need to do to implement my own step.

Let me try this again with a much simpler example. Say I want to sample from the joint distribution p(x, p), with p ~ Beta(1, 1) and x ~ Binomial(100, p), by sampling x | p ~ Binomial(100, p) and p | x ~ Beta(1 + x, 1 + 100 - x) back and forth, over and over.
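To make the target concrete, that back-and-forth is just a few lines in plain NumPy:

import numpy as np

n, niters = 100, 10_000
p = np.empty(niters)
x = np.empty(niters, dtype=int)
p[0], x[0] = 0.5, 50
for i in range(1, niters):
    x[i] = np.random.binomial(n, p[i - 1])         # x | p
    p[i] = np.random.beta(1 + x[i], 1 + n - x[i])  # p | x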

I wrote this:

import numpy as np
import pymc3 as pm
from pymc3.step_methods.arraystep import BlockedStep

class BetaBinomStep(BlockedStep):
    def __init__(self, var, binom_var, n, alpha, beta):
        self.vars = [var]
        self.name = var.name
        self.alpha = alpha
        self.beta = beta
        
        self.binom_var = binom_var
        self.binom_n = n

    def step(self, point: dict):
        x = point[self.binom_var.name]
        point[self.name] = np.random.beta(self.alpha + x, self.beta + self.binom_n - x)
        
        return point

with pm.Model() as BetaBinom:
    n = 100
    alpha = 1
    beta = 1
    p = pm.Beta('p', alpha=alpha, beta=beta)
    k = pm.Binomial('k', n=n, p=p)
    
    step = BetaBinomStep(p, k, n, alpha, beta)
    
    out_ = pm.sample(10000, step=[BetaBinomStep])

And got this:

/usr/local/Caskroom/miniconda/base/lib/python3.9/site-packages/deprecat/classic.py:215: FutureWarning: In v4.0, pm.sample will return an `arviz.InferenceData` object instead of a `MultiTrace` by default. You can pass return_inferencedata=True or return_inferencedata=False to be safe and silence this warning.
  return wrapped_(*args_, **kwargs_)
Multiprocess sampling (4 chains in 4 jobs)
CompoundStep
>type: []
>NUTS: [p]
>Metropolis: [k]
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Input In [129], in <cell line: 21>()
     26 k = pm.Binomial('k', n=n, p=p)
     28 step = BetaBinomStep(p, k, n, alpha, beta)
---> 30 out_ = pm.sample(10000, step=[BetaBinomStep])

File /usr/local/Caskroom/miniconda/base/lib/python3.9/site-packages/deprecat/classic.py:215, in deprecat.<locals>.wrapper_function(wrapped_, instance_, args_, kwargs_)
    213         else:
    214             warnings.warn(message, category=category, stacklevel=_routine_stacklevel)
--> 215 return wrapped_(*args_, **kwargs_)

File /usr/local/Caskroom/miniconda/base/lib/python3.9/site-packages/pymc3/sampling.py:575, in sample(draws, step, init, n_init, initvals, trace, chain_idx, chains, cores, tune, progressbar, model, random_seed, discard_tuned_samples, compute_convergence_checks, callback, jitter_max_retries, start, return_inferencedata, idata_kwargs, mp_ctx, pickle_backend, **kwargs)
    573 _print_step_hierarchy(step)
    574 try:
--> 575     trace = _mp_sample(**sample_args, **parallel_args)
    576 except pickle.PickleError:
    577     _log.warning("Could not pickle model, sampling singlethreaded.")

File /usr/local/Caskroom/miniconda/base/lib/python3.9/site-packages/pymc3/sampling.py:1480, in _mp_sample(draws, tune, step, chains, cores, chain, random_seed, start, progressbar, trace, model, callback, discard_tuned_samples, mp_ctx, pickle_backend, **kwargs)
   1477         strace.setup(draws + tune, idx + chain)
   1478     traces.append(strace)
-> 1480 sampler = ps.ParallelSampler(
   1481     draws,
   1482     tune,
   1483     chains,
   1484     cores,
   1485     random_seed,
   1486     start,
   1487     step,
   1488     chain,
   1489     progressbar,
   1490     mp_ctx=mp_ctx,
   1491     pickle_backend=pickle_backend,
   1492 )
   1493 try:
   1494     try:

File /usr/local/Caskroom/miniconda/base/lib/python3.9/site-packages/pymc3/parallel_sampling.py:431, in ParallelSampler.__init__(self, draws, tune, chains, cores, seeds, start_points, step_method, start_chain_num, progressbar, mp_ctx, pickle_backend)
    428             raise ValueError("dill must be installed for pickle_backend='dill'.")
    429         step_method_pickled = dill.dumps(step_method, protocol=-1)
--> 431 self._samplers = [
    432     ProcessAdapter(
    433         draws,
    434         tune,
    435         step_method,
    436         step_method_pickled,
    437         chain + start_chain_num,
    438         seed,
    439         start,
    440         mp_ctx,
    441         pickle_backend,
    442     )
    443     for chain, seed, start in zip(range(chains), seeds, start_points)
    444 ]
    446 self._inactive = self._samplers.copy()
    447 self._finished = []

File /usr/local/Caskroom/miniconda/base/lib/python3.9/site-packages/pymc3/parallel_sampling.py:432, in <listcomp>(.0)
    428             raise ValueError("dill must be installed for pickle_backend='dill'.")
    429         step_method_pickled = dill.dumps(step_method, protocol=-1)
    431 self._samplers = [
--> 432     ProcessAdapter(
    433         draws,
    434         tune,
    435         step_method,
    436         step_method_pickled,
    437         chain + start_chain_num,
    438         seed,
    439         start,
    440         mp_ctx,
    441         pickle_backend,
    442     )
    443     for chain, seed, start in zip(range(chains), seeds, start_points)
    444 ]
    446 self._inactive = self._samplers.copy()
    447 self._finished = []

File /usr/local/Caskroom/miniconda/base/lib/python3.9/site-packages/pymc3/parallel_sampling.py:254, in ProcessAdapter.__init__(self, draws, tune, step_method, step_method_pickled, chain, seed, start, mp_ctx, pickle_backend)
    252 self._shared_point = {}
    253 self._point = {}
--> 254 for name, (shape, dtype) in step_method.vars_shape_dtype.items():
    255     size = 1
    256     for dim in shape:

File /usr/local/Caskroom/miniconda/base/lib/python3.9/site-packages/pymc3/step_methods/compound.py:79, in CompoundStep.vars_shape_dtype(self)
     77 dtype_shapes = {}
     78 for method in self.methods:
---> 79     dtype_shapes.update(method.vars_shape_dtype)
     80 return dtype_shapes

TypeError: 'property' object is not iterable

I’m not getting how I’m supposed to define my own step class. As far as I can tell, I’m following the example I linked closely: Using a custom step method for sampling from locally conjugate posterior distributions — PyMC documentation

My understanding is that a step class, on initialization, gets a variable or a list of variables that it applies to. Then its step method gets a point object (a dictionary) holding each variable’s last value, and it needs to return the same point object with the new sample in it. Is this wrong?
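In other words, I thought the whole contract was just something like this toy step (NoisyStep is a name I made up for illustration):

import numpy as np
from pymc3.step_methods.arraystep import BlockedStep

class NoisyStep(BlockedStep):
    def __init__(self, var):
        self.vars = [var]  # the variables this step is responsible for

    def step(self, point: dict):  # point maps variable names to last values
        point[self.vars[0].name] = np.random.normal()  # draw the next value
        return point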

Thanks again