Thanks for the response. I’m just having trouble figuring out what I need to do to implement my own step.
Let me try this again with a much simpler example. Say I want to sample from the joint distribution of p(x,p) with p~Beta(1, 1) and X~Binom(100, p) by sampling p(x|p) ~ Binom(100, p) and p(p|x) ~ Beta(1 + x, 1+100-x) back and forth over and over.
I wrote this:
from pymc3.step_methods.arraystep import BlockedStep
class BetaBinomStep(BlockedStep):
def __init__(self, var, binom_var, n, alpha, beta):
self.vars = [var]
self.var = self.vars[0]
self.name = var.name
self.alpha = alpha
self.beta = beta
self.binom_var = binom_var
self.binom_n = n
def step(self, point: dict):
p_name = self.var.name
x = point[self.binom_var.name]
point[self.name] = np.random.beta(self.alpha + x, self.beta + self.binom_n - x)
return point
with pm.Model() as BetaBinom:
n = 100
alpha = 1
beta = 1
p = pm.Beta('p', alpha=alpha, beta=beta)
k = pm.Binomial('k', n=n, p=p)
step = BetaBinomStep(p, k, n, alpha, beta)
out_ = pm.sample(10000, step=[BetaBinomStep])
And go this:
/usr/local/Caskroom/miniconda/base/lib/python3.9/site-packages/deprecat/classic.py:215: FutureWarning: In v4.0, pm.sample will return an `arviz.InferenceData` object instead of a `MultiTrace` by default. You can pass return_inferencedata=True or return_inferencedata=False to be safe and silence this warning.
return wrapped_(*args_, **kwargs_)
Multiprocess sampling (4 chains in 4 jobs)
CompoundStep
>type: []
>NUTS: [p]
>Metropolis: [k]
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Input In [129], in <cell line: 21>()
26 k = pm.Binomial('k', n=n, p=p)
28 step = BetaBinomStep(p, k, n, alpha, beta)
---> 30 out_ = pm.sample(10000, step=[BetaBinomStep])
File /usr/local/Caskroom/miniconda/base/lib/python3.9/site-packages/deprecat/classic.py:215, in deprecat.<locals>.wrapper_function(wrapped_, instance_, args_, kwargs_)
213 else:
214 warnings.warn(message, category=category, stacklevel=_routine_stacklevel)
--> 215 return wrapped_(*args_, **kwargs_)
File /usr/local/Caskroom/miniconda/base/lib/python3.9/site-packages/pymc3/sampling.py:575, in sample(draws, step, init, n_init, initvals, trace, chain_idx, chains, cores, tune, progressbar, model, random_seed, discard_tuned_samples, compute_convergence_checks, callback, jitter_max_retries, start, return_inferencedata, idata_kwargs, mp_ctx, pickle_backend, **kwargs)
573 _print_step_hierarchy(step)
574 try:
--> 575 trace = _mp_sample(**sample_args, **parallel_args)
576 except pickle.PickleError:
577 _log.warning("Could not pickle model, sampling singlethreaded.")
File /usr/local/Caskroom/miniconda/base/lib/python3.9/site-packages/pymc3/sampling.py:1480, in _mp_sample(draws, tune, step, chains, cores, chain, random_seed, start, progressbar, trace, model, callback, discard_tuned_samples, mp_ctx, pickle_backend, **kwargs)
1477 strace.setup(draws + tune, idx + chain)
1478 traces.append(strace)
-> 1480 sampler = ps.ParallelSampler(
1481 draws,
1482 tune,
1483 chains,
1484 cores,
1485 random_seed,
1486 start,
1487 step,
1488 chain,
1489 progressbar,
1490 mp_ctx=mp_ctx,
1491 pickle_backend=pickle_backend,
1492 )
1493 try:
1494 try:
File /usr/local/Caskroom/miniconda/base/lib/python3.9/site-packages/pymc3/parallel_sampling.py:431, in ParallelSampler.__init__(self, draws, tune, chains, cores, seeds, start_points, step_method, start_chain_num, progressbar, mp_ctx, pickle_backend)
428 raise ValueError("dill must be installed for pickle_backend='dill'.")
429 step_method_pickled = dill.dumps(step_method, protocol=-1)
--> 431 self._samplers = [
432 ProcessAdapter(
433 draws,
434 tune,
435 step_method,
436 step_method_pickled,
437 chain + start_chain_num,
438 seed,
439 start,
440 mp_ctx,
441 pickle_backend,
442 )
443 for chain, seed, start in zip(range(chains), seeds, start_points)
444 ]
446 self._inactive = self._samplers.copy()
447 self._finished = []
File /usr/local/Caskroom/miniconda/base/lib/python3.9/site-packages/pymc3/parallel_sampling.py:432, in <listcomp>(.0)
428 raise ValueError("dill must be installed for pickle_backend='dill'.")
429 step_method_pickled = dill.dumps(step_method, protocol=-1)
431 self._samplers = [
--> 432 ProcessAdapter(
433 draws,
434 tune,
435 step_method,
436 step_method_pickled,
437 chain + start_chain_num,
438 seed,
439 start,
440 mp_ctx,
441 pickle_backend,
442 )
443 for chain, seed, start in zip(range(chains), seeds, start_points)
444 ]
446 self._inactive = self._samplers.copy()
447 self._finished = []
File /usr/local/Caskroom/miniconda/base/lib/python3.9/site-packages/pymc3/parallel_sampling.py:254, in ProcessAdapter.__init__(self, draws, tune, step_method, step_method_pickled, chain, seed, start, mp_ctx, pickle_backend)
252 self._shared_point = {}
253 self._point = {}
--> 254 for name, (shape, dtype) in step_method.vars_shape_dtype.items():
255 size = 1
256 for dim in shape:
File /usr/local/Caskroom/miniconda/base/lib/python3.9/site-packages/pymc3/step_methods/compound.py:79, in CompoundStep.vars_shape_dtype(self)
77 dtype_shapes = {}
78 for method in self.methods:
---> 79 dtype_shapes.update(method.vars_shape_dtype)
80 return dtype_shapes
TypeError: 'property' object is not iterable
I’m not getting how I’m supposed to define my own step class. As far as I can tell I’m following the example I linked closely: Using a custom step method for sampling from locally conjugate posterior distributions — PyMC documentation
My understanding is that a step class on initiation gets a variable or list of variables that it applies to. And then in the step method is get’s a point object (a dictionary) with the previous variable’s last value and I need to return the same point object with the new sample in it. Is this wrong?
Thanks again