Hi @ckrapu, thanks for your response.
Yes, that notebook is one of the examples I came across. But I must admit there is too much going on in it (the stick-breaking, etc.), and I can’t figure out how to apply it to my problem. Is there a list somewhere of the keys the `point` dict must contain? Or maybe a simpler example?
As an MWE (minimal working example), take a simple linear fit:
```python
import pymc3 as pm
import numpy as np
import matplotlib.pyplot as plt
import arviz as az

# True slope and intercept, noise level and number of points
b = [2, 1.5]
sigma = 2
n = 200

x = np.linspace(start=-20, stop=20, num=n)
y = b[0] * x + b[1]
y_obs = y + sigma * np.random.randn(n)
```
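As a quick sanity check that does not involve PyMC3 at all, ordinary least squares via `np.polyfit` should roughly recover `b` from this data. The fixed seed is my addition, only so the numbers are reproducible.

```python
import numpy as np

# Same simulated data as above, but with a fixed seed.
np.random.seed(42)
b = [2, 1.5]
sigma = 2
n = 200
x = np.linspace(start=-20, stop=20, num=n)
y_obs = b[0] * x + b[1] + sigma * np.random.randn(n)

# Least-squares fit of a degree-1 polynomial; returns [slope, intercept].
slope, intercept = np.polyfit(x, y_obs, deg=1)
print(round(slope, 2), round(intercept, 2))
```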
Then I can fit it with `pymc3` like this:
```python
with pm.Model() as linear_model:
    bm = pm.Normal("bm", mu=0, sigma=2, shape=2)
    noise = pm.Gamma("noise", alpha=2, beta=1)
    y_observed = pm.Normal("y_observed", mu=bm[0] * x + bm[1], sigma=noise, observed=y_obs)

posterior = pm.sample(
    model=linear_model, chains=2, draws=3000, tune=1000, return_inferencedata=True
)
```
Now, if I want to sample `bm` manually, I could define my own sampler:
```python
def my_sampler(mu=0, sigma=2):
    # Two independent draws from N(mu, sigma^2), one per entry of bm
    b = np.random.normal(mu, sigma, 2)
    return b
```
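One thing worth keeping in mind: `my_sampler` never looks at `x` or `y_obs`, so it only ever produces draws from the N(0, 2) prior on `bm`. A step that simply returned its output would therefore give a `bm` chain that reproduces the prior, not the posterior near `b = [2, 1.5]`. A quick NumPy check of what it returns (the seed and the number of draws are arbitrary choices of mine):

```python
import numpy as np


def my_sampler(mu=0, sigma=2):
    # Same as above: nothing here depends on x or y_obs.
    return np.random.normal(mu, sigma, 2)


np.random.seed(1)
draws = np.array([my_sampler() for _ in range(20000)])
print(draws.shape)  # (20000, 2)
# Column means should be close to mu = 0 and standard deviations close to sigma = 2.
print(draws.mean(axis=0), draws.std(axis=0))
```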
Then I would construct a new model roughly like this (I pieced it together from the notebook linked above and from this older example that I found):
```python
class MySamplingStep(object):
    def __init__(self, var, mu, sigma):
        self.vars = var
        self.mu = mu
        self.sigma = sigma

    def step(self, point: dict):
        new = point.copy()
        new[self.var] = my_sampler(self.mu, self.sigma)
        return new


mu = 0
sigma = 2

with pm.Model() as my_linear_model:
    bm = pm.Normal("bm", mu=0, sigma=2, shape=2)
    step_bm = MySamplingStep(bm, mu, sigma)
    noise = pm.Gamma("noise", alpha=2, beta=1)
    y_observed = pm.Normal("y_observed", mu=bm[0] * x + bm[1], sigma=noise, observed=y_obs)

posterior2 = pm.sample(
    model=my_linear_model,
    step=[step_bm],
    chains=2,
    draws=3000,
    tune=1000,
    return_inferencedata=True,
)
```
That is, I pass the custom `bm` step to `pm.sample(...)` via `step=[step_bm]`. But I keep getting the error message `Length of bm ~ Normal cannot be determined`.
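For reference, here is how I currently understand the contract, written as a plain NumPy sketch that does not import PyMC3, so treat all of it as my assumption rather than the documented API. A step receives the current point as a dict keyed by the names (strings) of the model's free variables on the sampler's transformed scale. For this model I would expect `"bm"` (shape `(2,)`) and `"noise_log__"` (the Gamma variable on the log scale), and the step returns a new dict with the same keys. In PyMC3 3.x the real class would presumably also subclass `pymc3.step_methods.arraystep.BlockedStep` and store `self.vars` as a *list* such as `[var]`; the error above looks consistent with PyMC3 trying to iterate over the bare `bm` tensor stored in `self.vars`. The class name `PriorDrawStep` and the toy loop are hypothetical.

```python
import numpy as np


def my_sampler(mu=0, sigma=2):
    # Same prior draw as above: two independent N(mu, sigma^2) values.
    return np.random.normal(mu, sigma, 2)


class PriorDrawStep:
    """Hypothetical step that replaces one entry of the point dict.

    Assumption: in PyMC3 3.x this would subclass BlockedStep, take the model
    variable itself, and set self.vars = [var] and self.name = var.name.
    Here only the name is passed so the sketch runs without PyMC3.
    """

    def __init__(self, name, mu, sigma):
        self.name = name  # key of the variable in the point dict, e.g. "bm"
        self.mu = mu
        self.sigma = sigma

    def step(self, point: dict) -> dict:
        new = point.copy()  # leave the incoming dict untouched
        new[self.name] = my_sampler(self.mu, self.sigma)
        return new


# Assumed layout of a point for this model (noise stored on the log scale).
point = {"bm": np.zeros(2), "noise_log__": np.log(2.0)}

np.random.seed(0)
step_bm = PriorDrawStep("bm", mu=0, sigma=2)
trace = []
for _ in range(5):
    # In PyMC3 the step for noise_log__ (e.g. NUTS) would also run each iteration.
    point = step_bm.step(point)
    trace.append(point["bm"])

print(sorted(point))  # ['bm', 'noise_log__']
print(np.array(trace).shape)  # (5, 2)
```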
If you or someone else could point me towards how to do this, that would be highly appreciated. Once this example runs, I should be able to apply it to my actual problem.