# Custom sampling step

import aesara
import aesara.tensor as at
import pymc as pm
import numpy as np

from pymc.step_methods.arraystep import BlockedStep

class UpdateFakeVariablesStep(BlockedStep):
    """Blocked step that overwrites a fixed set of "fake" variables every draw.

    Instead of proposing values from the model's log-probability, each call to
    ``step`` draws one random integer and writes it, cast to each variable's
    dtype, into the point for every variable in ``self.vars``.
    """

    def __init__(self, fake_vars, trace, seed=None, model=None, *, high=100):
        """Create the step.

        Parameters
        ----------
        fake_vars : list
            Model random variables whose values this step sets. Each is mapped
            to its value variable via ``model.rvs_to_values``; anything not in
            that mapping is used as-is.
        trace :
            Stored on ``self.trace`` but not read by ``step``.
            NOTE(review): presumably intended to hold a previous InferenceData
            to resample values from -- TODO implement or remove.
        seed : optional
            Seed passed to ``numpy.random.default_rng``.
        model : pm.Model, optional
            Model to use; defaults to the model on the context stack.
        high : int, keyword-only, default 100
            Exclusive upper bound of the random integer drawn on each step.

        Raises
        ------
        ValueError
            If ``high`` is less than 1 (no integer could be drawn).
        """
        if high < 1:
            raise ValueError(f"high must be a positive integer, got {high!r}")
        model = pm.modelcontext(model)
        self.vars = [model.rvs_to_values.get(v, v) for v in fake_vars]
        self.rng = np.random.default_rng(seed=seed)
        self.trace = trace
        self.high = high

    def step(self, point: dict):
        """Return a copy of ``point`` with every fake variable set to one random integer.

        A single integer in ``[0, self.high)`` is drawn per call and shared by
        all variables in ``self.vars``.

        NOTE(review): this returns only the point, matching the Aesara-era
        (PyMC 4) API this file imports; PyMC >= 5 expects ``step`` to return a
        ``(point, stats)`` tuple -- confirm before upgrading.
        """
        draw = self.rng.integers(self.high)
        # Work on a copy so the caller's dict is not mutated behind its back.
        new_point = dict(point)
        for v in self.vars:
            # Cast to the value variable's dtype: the Flat variables are
            # continuous, so writing a raw int64 array would mismatch them.
            new_point[v.name] = np.asarray(draw, dtype=v.dtype)
        return new_point


# Model inputs. Only temp[1] is read below.
temp = np.array([0, 1])
# Observed data for the likelihood; both values share the scalar mean local_slr.
obs = np.array([0.3, 3.5])

# Guard model building and sampling: with chains=2, pm.sample may start worker
# processes, and under the "spawn"/"forkserver" start methods (e.g. the default
# on Windows) each worker re-imports this module. Without the guard, that
# re-import would run pm.sample again inside every worker.
if __name__ == "__main__":
    with pm.Model() as local_model:
        # a and aT0 have flat priors; their values are written directly by
        # UpdateFakeVariablesStep instead of being sampled.
        a = pm.Flat('a')
        aT0 = pm.Flat('aT0')
        update_fake_step = UpdateFakeVariablesStep([a, aT0], trace=None)

        global_slr = pm.Deterministic('global_slr', a * temp[1] - aT0)
        land_motion = pm.Normal('land_motion', 0, 2)
        local_slr = pm.Deterministic('local_slr', global_slr + land_motion)

        pm.Normal('local_slr_obs', local_slr, sigma=0.2, observed=obs)

        # Only the custom step is passed explicitly; PyMC assigns its default
        # step method to the remaining free variable (land_motion).
        local_trace = pm.sample(step=[update_fake_step], chains=2)
# 1 Like