Creating a custom step function

I would like to write my own sampler, specifically for the Reversible Jump Markov Chain Monte Carlo algorithm. To do so, I’m trying to find guides on samplers and step functions in PyMC3.

So far I’ve found this guide, which covers the topic. But it’s more of a recipe and doesn’t really explain the thinking behind samplers and step functions in PyMC3.

Do you guys have any material that talks more about how this part of the API is conceptually organized?

If we don’t have one, we should definitely make one! If nobody replies here for a while, I would suggest opening an issue on GitHub requesting new documentation for exactly this.

There is also some information in Compound Steps in Sampling — PyMC3 3.11.4 documentation

In general, you can inherit from ArrayStep and implement the astep method.
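To make the division of labor concrete, here is a minimal sketch in plain Python. `ToyArrayStep`, `ToyMetropolis`, `std_normal_logp` and `run_chain` are hypothetical stand-ins written for this illustration, not PyMC classes. The base class only converts between a point dictionary and a flat vector, and the subclass supplies the actual move in `astep`.

```python
import math
import random


class ToyArrayStep:
    """Stand-in for the ArrayStep idea (not PyMC's class): step() flattens a
    point dict, delegates to astep(), and writes the result back."""

    def __init__(self, var_names):
        self.var_names = list(var_names)

    def step(self, point):
        q0 = [point[name] for name in self.var_names]
        q_new, stats = self.astep(q0)
        new_point = dict(point)  # variables we do not own are left untouched
        new_point.update(zip(self.var_names, q_new))
        return new_point, stats

    def astep(self, q0):
        raise NotImplementedError


class ToyMetropolis(ToyArrayStep):
    """Random-walk Metropolis on a user-supplied log density."""

    def __init__(self, var_names, logp, scale=1.0, seed=None):
        super().__init__(var_names)
        self.logp = logp
        self.scale = scale
        self.rng = random.Random(seed)

    def astep(self, q0):
        proposal = [x + self.rng.gauss(0.0, self.scale) for x in q0]
        log_ratio = self.logp(proposal) - self.logp(q0)
        accept_prob = math.exp(min(0.0, log_ratio))
        accepted = self.rng.random() < accept_prob
        stats = {"accepted": accepted, "accept": accept_prob}
        return (proposal if accepted else q0), [stats]


def std_normal_logp(q):
    """Log density of independent standard normals, up to a constant."""
    return -0.5 * sum(x * x for x in q)


def run_chain(n_steps=5000, seed=0):
    """Run the toy sampler and return the sample mean of variable 'a'."""
    sampler = ToyMetropolis(["a", "b"], std_normal_logp, seed=seed)
    point = {"a": 0.0, "b": 0.0}
    total = 0.0
    for _ in range(n_steps):
        point, _stats = sampler.step(point)
        total += point["a"]
    return total / n_steps
```

The real ArrayStep does more bookkeeping than this, but the shape of the contract is the same: `astep` receives the current flat state and returns the next state together with a list of stats dictionaries.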

You can also take a look at an older PR where a new MCMC step method was added: Add DEMetropolisZ stepper by michaelosthege · Pull Request #3784 · pymc-devs/pymc · GitHub (be aware of the style and syntax changes in v4, though).

Thanks for the article. The picture is getting clearer but there are things that are still confusing to me.
In the astep function we have apoint and point as arguments.
But when I look at the overridden version in Metropolis, it only uses one argument:

astep(self, q0: RaveledVars) 

So which point type should I be using, and what’s expected to be in there?

q0 is a RaveledVars in this case, which is the concatenated and flattened representation of the free parameters that PyMC passes in internally during sampling. You should indeed be using this.
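As a rough picture of what "concatenate and flatten" means, the sketch below ravels a point dictionary into one flat list plus the bookkeeping needed to undo it. `ravel_point` and `unravel_point` are hypothetical helpers for illustration only; they handle just scalars and 1-D lists and are not PyMC's implementation.

```python
def ravel_point(point, var_names):
    """Flatten scalars and 1-D lists into one list, recording each variable's size."""
    flat, info = [], []
    for name in var_names:
        value = point[name]
        if isinstance(value, (list, tuple)):
            flat.extend(float(v) for v in value)
            info.append((name, len(value)))
        else:
            flat.append(float(value))
            info.append((name, None))  # None marks a scalar
    return flat, info


def unravel_point(flat, info):
    """Rebuild the point dictionary from the flat list and the recorded sizes."""
    point, pos = {}, 0
    for name, size in info:
        if size is None:
            point[name] = flat[pos]
            pos += 1
        else:
            point[name] = flat[pos:pos + size]
            pos += size
    return point
```

In this picture, the flat list plays the role of the `data` field and the size records play the role of `point_map_info`, the two attribute names used later in this thread.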

So I’ve tried to interpret the Metropolis step function as best I can and came up with the following implementation.

The probabilistic model has continuous parameters k as well as discrete parameters \Delta that indicate whether or not the corresponding parameter k is active. (This way I can give the model with k=0 a nonzero probability measure.)

The code then decides with probability p_delta_flip to switch on a delta and randomly choose a value for the k that is now active.
If it doesn’t choose to flip a delta, then it only needs to adjust the k parameters that were previously active. It should do this using one of the already existing step functions.
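The two kinds of move described above can be sketched on a toy problem with a single (delta, k) pair. Everything here is an assumption made for illustration: one observation y with a Gaussian likelihood centered on k, a Bernoulli(0.5) prior on delta, a Uniform(0, 1) prior on an active k, and a random-walk Metropolis update standing in for the within-space stepper. Because the birth proposal for k is drawn from its prior, the prior, proposal and Jacobian terms cancel and the acceptance ratio reduces to a likelihood ratio.

```python
import math
import random


def log_lik(k, y=0.3, sigma=0.3):
    """Gaussian log-likelihood of one observation y with mean k (constants dropped)."""
    return -0.5 * ((y - k) / sigma) ** 2


def rj_step(delta, k, rng, p_flip=0.5, walk_scale=0.2):
    """One update: either flip delta (changing dimension) or move k within the space."""
    if rng.random() < p_flip:
        if delta == 0:
            new_delta, new_k = 1, rng.random()  # birth: draw k from Uniform(0, 1)
        else:
            new_delta, new_k = 0, 0.0           # death: switch k off
    else:
        if delta == 0:
            return delta, k, False              # no active dimension to move in
        new_delta, new_k = 1, k + rng.gauss(0.0, walk_scale)
        if not 0.0 <= new_k <= 1.0:
            return delta, k, False              # outside the prior support: reject
    log_ratio = log_lik(new_k) - log_lik(k)
    accepted = rng.random() < math.exp(min(0.0, log_ratio))
    return (new_delta, new_k, True) if accepted else (delta, k, False)


def run_rj(n_steps=50000, seed=1):
    """Return the fraction of iterations in which delta was switched on."""
    rng = random.Random(seed)
    delta, k, active = 0, 0.0, 0
    for _ in range(n_steps):
        delta, k, _accepted = rj_step(delta, k, rng)
        active += delta
    return active / n_steps
```

For these toy settings the posterior probability of delta = 1 works out analytically to roughly 0.51, so `run_rj()` should land near that value. With a birth proposal that differs from the prior, the proposal density would have to appear explicitly in the acceptance ratio.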

I tried to use the preexisting step functions by instantiating them during __init__ for each \Delta configuration that could arise.
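A small sketch of how the configurations could be enumerated, assuming the deltas are identified by name. `enumerate_configs` is a hypothetical helper; it uses the on/off tuple itself as the dictionary key, which avoids having to build a separate numeric id.

```python
import itertools


def enumerate_configs(delta_names):
    """Map every on/off pattern of the deltas to the names that are switched on."""
    configs = {}
    for combo in itertools.product((0, 1), repeat=len(delta_names)):
        configs[combo] = [name for bit, name in zip(combo, delta_names) if bit == 1]
    return configs
```

Note that the number of patterns is 2**n for n deltas, so building one stepper per pattern up front is only practical for a handful of switches. Creating steppers lazily, the first time a pattern is visited, is one possible alternative.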

Would you mind giving me some commentary on the general interpretations I made in this code? For example, how to use other step functions, and whether I am using the stats property correctly. Thanks.

code below (more pseudocode):

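import itertools
import random
from functools import reduce

import numpy as np

# Metropolis, Slice, metrop_select and RaveledVars are PyMC names; their import
# paths differ between PyMC3 and v4, so they are left out here.

class RJMCMC(Metropolis):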
    """
    Represents a pymc3 step function for rjmcmc algorithm
    """

    def __init__(self, var_assoc, model=None, **kwargs):
        """
        var_assoc: [(delta, k)] list of variables with the switch variables
        delta = 0 or 1 indicates if the parameter is active or not
        k is a continuous parameter of the distribution
        when a k has its delta set to 0, there should be no stepping in its dimension
        This algorithm hopes to reduce the average dimension where stepping happens, hopefully speeding things up
        """
        deltas = [d for d,_ in var_assoc]
        ks = [k for _,k in var_assoc]
        vars = deltas + ks # Everything together
        super().__init__(vars, tune=False, model=model, **kwargs)

        self.associations = {x.name:y.name for x, y in var_assoc}
        self.deltas       = {d.name:d for d, _ in var_assoc}
        self.ks           = {k.name:k for _, k in var_assoc}

        self.fixed_delta_order = list(self.deltas)

        # Outsource stepping within a space to some other step functions
        # For the moment I'll use slice since no gradient on logp as of yet
        self.subspace_steppers = {}
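        # itertools.product enumerates every on/off pattern of the deltas
        for combo in itertools.product((0, 1), repeat=len(deltas)):
            # Names of the k parameters whose delta is switched on in this pattern
            selected = [self.associations[d] for s, d in zip(combo, self.fixed_delta_order) if s == 1]

            selected_vars = [x for x in self.vars if x.name in selected]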

            self.subspace_steppers[self.get_config_id_from_combo(combo)] = Slice(selected_vars)

    def get_config_id_from_combo(self, combo):
        return reduce(lambda x,y: x + y, (x * 10**i for i, x in enumerate(combo)))

    def get_config_id(self, deltas):
        """
        Return unique configuration id for a given set of deltas = {dname:value}
        the id is a digit where each position is the value of a delta
        """
        ordered_delta_values = [deltas[x] for x in self.fixed_delta_order]
        return self.get_config_id_from_combo(ordered_delta_values)

    def astep(self, q0):
        raise(NotImplementedError())

        # Separate the discrete from the continuous parameters
        # Transform the dictionary q0.data

        q0_deltas = {x:y for x,y in q0.data.items() if x in self.deltas}
        q0_ks     = {x:y for x,y in q0.data.items() if x in self.ks}

        # There are two categories, flip a delta, or step in the space only
        p_delta_flip = 0.1 # We should mostly stay within the space of interest TODO this should be tuned
        nb_deltas = len(q0_deltas)

        # decide if flipping a delta
        q_new_data = {x:y for x, y in q0.data.items()}
        stats = {}
        if np.random.rand() < p_delta_flip:
            # Pick the delta to flip
            idx = random.choice(list(q0_deltas))  # pick one delta name at random

            # Generate the new point with dimension matching
            u = np.random.rand()
            q_new_data[idx] = 1 - q_new_data[idx]
            if q_new_data[idx] > 0:
                q_new_data[self.associations[idx]] = u
            else:
                q_new_data[self.associations[idx]] = 0

        else:
            # Nothing to be flipped, take a step within the current space 
            # Get only the value in the dimensions that matter
            selected_components = {self.associations[dname]:q_new_data[self.associations[dname]] for dname,val in q0_deltas.items() if val>0}
            q_partial, slice_stats = self.subspace_steppers[self.get_config_id(q0_deltas)].step(selected_components)
            stats.update(slice_stats)
            q_new_data.update(q_partial)

        # Use metropolis select with precomputed acceptance probability
        # This is probably relying on a new call to likelihood, figure out how to do without it
        accept = self.delta_logp(q_new_data, q0.data)
        q_new_data, accepted = metrop_select(accept, q_new_data, q0.data)

        stats.update({"accepted": accepted, "accept": np.exp(accept)})

        return RaveledVars(q_new_data, q0.point_map_info), [stats]

I am not too familiar with the details of RJMCMC, but the overall logic seems correct to me.
As long as astep does something like astep(x_old: RaveledVars) -> (RaveledVars, stats), you are on the right track.