Writing a Gibbs sampler in PyMC is just as hard and coding it from scratch?

Thanks for the response. I’m just having trouble figuring out what I need to do to implement my own step.

Let me try this again with a much simpler example. Say I want to sample from the joint distribution of p(x,p) with p~Beta(1, 1) and X~Binom(100, p) by sampling p(x|p) ~ Binom(100, p) and p(p|x) ~ Beta(1 + x, 1+100-x) back and forth over and over.

I wrote this:

from pymc3.step_methods.arraystep import BlockedStep

class BetaBinomStep(BlockedStep):
    def __init__(self, var, binom_var, n, alpha, beta):
        self.vars = [var]
        self.var = self.vars[0]
        self.name = var.name
        self.alpha = alpha
        self.beta = beta
        
        self.binom_var = binom_var
        self.binom_n = n

    def step(self, point: dict):
        p_name = self.var.name
        x = point[self.binom_var.name]
        point[self.name] = np.random.beta(self.alpha + x, self.beta + self.binom_n - x)
        
        return point

with pm.Model() as BetaBinom:
    n = 100
    alpha = 1
    beta = 1
    p = pm.Beta('p', alpha=alpha, beta=beta)
    k = pm.Binomial('k', n=n, p=p)
    
    step = BetaBinomStep(p, k, n, alpha, beta)
    
    out_ = pm.sample(10000, step=[BetaBinomStep])

And go this:

/usr/local/Caskroom/miniconda/base/lib/python3.9/site-packages/deprecat/classic.py:215: FutureWarning: In v4.0, pm.sample will return an `arviz.InferenceData` object instead of a `MultiTrace` by default. You can pass return_inferencedata=True or return_inferencedata=False to be safe and silence this warning.
  return wrapped_(*args_, **kwargs_)
Multiprocess sampling (4 chains in 4 jobs)
CompoundStep
>type: []
>NUTS: [p]
>Metropolis: [k]
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Input In [129], in <cell line: 21>()
     26 k = pm.Binomial('k', n=n, p=p)
     28 step = BetaBinomStep(p, k, n, alpha, beta)
---> 30 out_ = pm.sample(10000, step=[BetaBinomStep])

File /usr/local/Caskroom/miniconda/base/lib/python3.9/site-packages/deprecat/classic.py:215, in deprecat.<locals>.wrapper_function(wrapped_, instance_, args_, kwargs_)
    213         else:
    214             warnings.warn(message, category=category, stacklevel=_routine_stacklevel)
--> 215 return wrapped_(*args_, **kwargs_)

File /usr/local/Caskroom/miniconda/base/lib/python3.9/site-packages/pymc3/sampling.py:575, in sample(draws, step, init, n_init, initvals, trace, chain_idx, chains, cores, tune, progressbar, model, random_seed, discard_tuned_samples, compute_convergence_checks, callback, jitter_max_retries, start, return_inferencedata, idata_kwargs, mp_ctx, pickle_backend, **kwargs)
    573 _print_step_hierarchy(step)
    574 try:
--> 575     trace = _mp_sample(**sample_args, **parallel_args)
    576 except pickle.PickleError:
    577     _log.warning("Could not pickle model, sampling singlethreaded.")

File /usr/local/Caskroom/miniconda/base/lib/python3.9/site-packages/pymc3/sampling.py:1480, in _mp_sample(draws, tune, step, chains, cores, chain, random_seed, start, progressbar, trace, model, callback, discard_tuned_samples, mp_ctx, pickle_backend, **kwargs)
   1477         strace.setup(draws + tune, idx + chain)
   1478     traces.append(strace)
-> 1480 sampler = ps.ParallelSampler(
   1481     draws,
   1482     tune,
   1483     chains,
   1484     cores,
   1485     random_seed,
   1486     start,
   1487     step,
   1488     chain,
   1489     progressbar,
   1490     mp_ctx=mp_ctx,
   1491     pickle_backend=pickle_backend,
   1492 )
   1493 try:
   1494     try:

File /usr/local/Caskroom/miniconda/base/lib/python3.9/site-packages/pymc3/parallel_sampling.py:431, in ParallelSampler.__init__(self, draws, tune, chains, cores, seeds, start_points, step_method, start_chain_num, progressbar, mp_ctx, pickle_backend)
    428             raise ValueError("dill must be installed for pickle_backend='dill'.")
    429         step_method_pickled = dill.dumps(step_method, protocol=-1)
--> 431 self._samplers = [
    432     ProcessAdapter(
    433         draws,
    434         tune,
    435         step_method,
    436         step_method_pickled,
    437         chain + start_chain_num,
    438         seed,
    439         start,
    440         mp_ctx,
    441         pickle_backend,
    442     )
    443     for chain, seed, start in zip(range(chains), seeds, start_points)
    444 ]
    446 self._inactive = self._samplers.copy()
    447 self._finished = []

File /usr/local/Caskroom/miniconda/base/lib/python3.9/site-packages/pymc3/parallel_sampling.py:432, in <listcomp>(.0)
    428             raise ValueError("dill must be installed for pickle_backend='dill'.")
    429         step_method_pickled = dill.dumps(step_method, protocol=-1)
    431 self._samplers = [
--> 432     ProcessAdapter(
    433         draws,
    434         tune,
    435         step_method,
    436         step_method_pickled,
    437         chain + start_chain_num,
    438         seed,
    439         start,
    440         mp_ctx,
    441         pickle_backend,
    442     )
    443     for chain, seed, start in zip(range(chains), seeds, start_points)
    444 ]
    446 self._inactive = self._samplers.copy()
    447 self._finished = []

File /usr/local/Caskroom/miniconda/base/lib/python3.9/site-packages/pymc3/parallel_sampling.py:254, in ProcessAdapter.__init__(self, draws, tune, step_method, step_method_pickled, chain, seed, start, mp_ctx, pickle_backend)
    252 self._shared_point = {}
    253 self._point = {}
--> 254 for name, (shape, dtype) in step_method.vars_shape_dtype.items():
    255     size = 1
    256     for dim in shape:

File /usr/local/Caskroom/miniconda/base/lib/python3.9/site-packages/pymc3/step_methods/compound.py:79, in CompoundStep.vars_shape_dtype(self)
     77 dtype_shapes = {}
     78 for method in self.methods:
---> 79     dtype_shapes.update(method.vars_shape_dtype)
     80 return dtype_shapes

TypeError: 'property' object is not iterable

I’m not getting how I’m supposed to define my own step class. As far as I can tell I’m following the example I linked closely: Using a custom step method for sampling from locally conjugate posterior distributions — PyMC documentation

My understanding is that a step class on initiation gets a variable or list of variables that it applies to. And then in the step method is get’s a point object (a dictionary) with the previous variable’s last value and I need to return the same point object with the new sample in it. Is this wrong?

Thanks again