import aesara
import aesara.tensor as at
import pymc as pm
import numpy as np
from pymc.step_methods.arraystep import BlockedStep
class UpdateFakeVariablesStep(BlockedStep):
    """Custom PyMC step that overwrites selected variables with placeholder values.

    On every sampler iteration the step draws a single random integer index
    and writes it, cast to each variable's dtype, into the point for every
    variable in ``fake_vars``. The ``trace`` argument is stored but is not
    used yet (see the TODO in :meth:`step`).

    Parameters
    ----------
    fake_vars : list
        Model random variables (or their value variables) that this step is
        responsible for updating.
    trace :
        Kept on the instance for a future trace-based update; currently
        unused. Presumably an ArviZ ``InferenceData`` -- TODO confirm.
    seed : int or None, optional
        Seed for the step's private ``numpy.random.Generator``.
    model : pm.Model or None, optional
        Model used to resolve ``fake_vars``; defaults to the model on the
        context stack.
    n_draws : int, optional
        Exclusive upper bound of the random index (default 100, the value
        that used to be hard-coded).
    """

    def __init__(self, fake_vars, trace, seed=None, model=None, n_draws=100):
        model = pm.modelcontext(model)
        # Point dicts are keyed by value-variable names, so map each random
        # variable to its value variable; anything without a mapping (e.g. a
        # value variable passed directly) is used unchanged.
        self.vars = [model.rvs_to_values.get(v, v) for v in fake_vars]
        self.rng = np.random.default_rng(seed=seed)
        self.trace = trace
        self.n_draws = n_draws

    def step(self, point: dict):
        """Write one random index into ``point`` for every handled variable.

        Parameters
        ----------
        point : dict
            Mapping from value-variable name to its current value. It is
            modified in place.

        Returns
        -------
        dict
            The same ``point`` object, updated.
        """
        i = self.rng.integers(self.n_draws)
        # TODO: an earlier draft intended to copy values from `self.trace`,
        # roughly `self.trace.posterior[v.name].sel(draw=i, chain=0).values`
        # with `i` drawn from `range(self.trace.posterior.draw.size)`.
        for v in self.vars:
            # Cast to the variable's own dtype. An int64 value may be
            # rejected by Aesara when the variable is float32.
            point[v.name] = np.array(i, dtype=v.dtype)
        return point
# Small fixed inputs for the demo model.
temp = np.array([0, 1])
obs = np.array([0.3, 3.5])

with pm.Model() as local_model:
    # `a` and `aT0` have flat priors and are handed to the custom step below,
    # which overwrites them on every iteration.
    a = pm.Flat('a')
    aT0 = pm.Flat('aT0')
    update_fake_step = UpdateFakeVariablesStep([a, aT0], trace=None)

    global_slr = pm.Deterministic('global_slr', a * temp[1] - aT0)
    land_motion = pm.Normal('land_motion', 0, 2)
    local_slr = pm.Deterministic('local_slr', global_slr + land_motion)
    pm.Normal('local_slr_obs', local_slr, sigma=0.2, observed=obs)

    # Variables not covered by `update_fake_step` (here `land_motion`) are
    # assigned a sampler automatically by PyMC.
    # NOTE(review): with chains=2 PyMC may sample in worker processes; on
    # platforms using the multiprocessing "spawn" start method this
    # module-level script likely needs an `if __name__ == "__main__":` guard.
    local_trace = pm.sample(step=[update_fake_step], chains=2)