Custom sampling step

Hi @ckrapu, thanks for your response.

Yes, that notebook is one of the examples I came across, but I must admit there is too much going on in it (the stick breaking, etc.) and I can’t figure out how to apply it to my problem. Is there a list somewhere of the keys the point dict must have? Or maybe a simpler example?

As a minimal working example (MWE), consider a linear fit:

import pymc3 as pm
import numpy as np
import matplotlib.pyplot as plt
import arviz as az

b = [2, 1.5]
sigma = 2

n = 200
x = np.linspace(start=-20, stop=20, num=n)
y = b[0] * x + b[1]
y_obs = y + sigma * np.random.randn(n)

Then I can fit it with PyMC3 like this:

with pm.Model() as linear_model:
    bm = pm.Normal("bm", mu=0, sigma=2, shape=2)
    noise = pm.Gamma("noise", alpha=2, beta=1)

    y_observed = pm.Normal("y_observed", mu=bm[0] * x + bm[1], sigma=noise, observed=y_obs)

posterior = pm.sample(model=linear_model, chains=2, draws=3000, tune=1000, return_inferencedata=True)

Now, if I want to sample bm manually, I could define my own sampler:

def my_sampler(mu=0, sigma=2):
    # Draw the two coefficients independently from a normal distribution
    b = np.random.normal(mu, sigma, 2)
    return b

and then construct a new model roughly like this (I pieced it together from the link above and from this older example I found):

class MySamplingStep(object):
    def __init__(self, var, mu, sigma):
        self.vars = var
        self.mu = mu
        self.sigma = sigma

    def step(self, point: dict):
        new = point.copy()
        new[self.var] = my_sampler(self.mu, self.sigma)

        return new

mu = 0
sigma = 2

with pm.Model() as my_linear_model:
    bm = pm.Normal("bm", mu=0, sigma=2, shape=2)
    step_bm = MySamplingStep(bm, mu, sigma)
    noise = pm.Gamma("noise", alpha=2, beta=1)

    y_observed = pm.Normal("y_observed", mu=bm[0] * x + bm[1], sigma=noise, observed=y_obs)

posterior2 = pm.sample(model=my_linear_model, step=[step_bm], chains=2, draws=3000, tune=1000, return_inferencedata=True)

The idea is to pass the bm step to pm.sample(...) through the step argument, but I keep getting the error message: Length of bm ~ Normal cannot be determined.

If you or someone else could point me in the right direction, that would be highly appreciated :slight_smile: Once this example runs, I should be able to apply it to my actual problem.