Custom sampling step

I am not sure that approach was sound in V3, because the categorical draw would be updated in every inner_step of NUTS, and I imagine you want to hold that constant and only change between complete NUTS steps.

Yes, probably. (EDIT: perhaps the modelling was straightforward enough, but that approach nevertheless gave reasonable results when I examined them.)

You gave me plenty of material to improve the work. Looking forward to putting it into practice.


As an update: the method proposed by @ricardoV94 seems to work, though in my use case I see no difference from the other methods (same results and convergence time), and unfortunately I still get a 2x to 3x slower execution in pymc v4 than in pymc3. While adapting the approach to pymc3 (to compare execution time and results, and to address the issues @ricardoV94 raised, namely that a) RandomStream updates the variable too often and b) a custom step with dummy Normal variables invalidates the logp), I had to introduce a dummy variable so that the SharedVariableUpdate step would actually be executed. Something like:

import numpy as np
import pymc3 as pm
import theano
from pymc3.step_methods.arraystep import BlockedStep

class UpdateSharedVariablesStepV3(BlockedStep):

    def __init__(self, shared_vars, trace, seed=None):
        self.vars = shared_vars
        self.rng = np.random.default_rng(seed=seed)
        self.trace = trace

    def step(self, point: dict):
        # Draw one index per complete sampler step
        i = self.rng.integers(100)
        #i = self.rng.integers(self.trace.posterior.draw.size)

        for v in self.vars:
            if v.name.startswith("dummy"):
                # The dummy RV only exists so that this step gets executed
                point[v.name] = 0.
                continue
            v.set_value(i)
            #v.set_value(self.trace.posterior[v.name].sel(draw=i, chain=0).values)

        return point

with pm.Model() as local_model:
    a = theano.shared(0.0, name='a')
    aT0 = theano.shared(0.0, name='aT0')
    dummy_step_variable = pm.Normal('dummy_step_variable', 0, 1) # does it also impact logp in any undesirable way?
    update_shared_step = UpdateSharedVariablesStepV3([a, aT0, dummy_step_variable], trace=None)
    ...
    local_trace = pm.sample(step=[update_shared_step], chains=1)
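The ordering concern raised at the start of this thread (draw the index once per complete sampler step, not in every NUTS inner step) can be illustrated without PyMC. Below is a minimal NumPy sketch, not PyMC's implementation: the names `compound_sampler`, `stored_a` and `motion` are hypothetical, and a random-walk Metropolis update stands in for NUTS. Each iteration first draws an index uniformly into a set of stored draws, then updates `motion` with that draw held fixed for the whole update.

```python
import numpy as np


def compound_sampler(stored_a, obs, n_iter=2000, step_size=0.5, seed=0):
    """Toy two-step sampler: an index step followed by a Metropolis step.

    stored_a : 1-D array of pre-existing draws of a parameter `a` (hypothetical)
    obs      : observations modelled as Normal(a[idx] + motion, 0.2)
    """
    rng = np.random.default_rng(seed)
    motion = 0.0
    idx_trace = np.empty(n_iter, dtype=int)
    motion_trace = np.empty(n_iter)

    def logp(m, a_val):
        # Prior motion ~ Normal(0, 2); likelihood obs ~ Normal(a_val + m, 0.2)
        return -0.5 * (m / 2.0) ** 2 - 0.5 * np.sum(((obs - a_val - m) / 0.2) ** 2)

    for t in range(n_iter):
        # Step 1: draw a new index once per full iteration (uniform over stored draws)
        idx = rng.integers(stored_a.size)
        a_val = stored_a[idx]
        # Step 2: Metropolis update of `motion`, with a_val fixed for the whole update
        proposal = motion + step_size * rng.normal()
        if np.log(rng.uniform()) < logp(proposal, a_val) - logp(motion, a_val):
            motion = proposal
        idx_trace[t] = idx
        motion_trace[t] = motion
    return idx_trace, motion_trace
```

Because the index is drawn uniformly, independently of the data, information only flows from the stored draws into `motion`, never back into the choice of draw.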

Hi @ricardoV94,
sorry to come back to this, but I am worried that the approach you proposed (updating a shared variable) does not work for multiple chains when multiple cores are used, not because of the seed but likely because of some bug related to multiprocessing. Additionally, there seems to be a bug when the shared value is used as the mean of a Normal distribution (which is more mysterious to me). I made a minimal working example to demonstrate the issue:

import pymc as pm
import aesara
from pymc.step_methods.arraystep import BlockedStep

class UpdateSharedVarStep(BlockedStep):

    def __init__(self, shared_vars):
        self.vars = shared_vars

    def step(self, point: dict) -> dict:
        for v in self.vars:
            v.set_value(3)
        return point
    
with pm.Model():
    a = aesara.shared(0., name='a')
    
    pm.Deterministic('deterministic', a) # save in trace
    pm.Normal('normal', a, 1e-2)
    pm.Uniform('uniform', a, a+1e-2)
    
    fixed_step = UpdateSharedVarStep([a])
    
    a.set_value(22.)    
    trace = pm.sample(step=fixed_step, chains=1, tune=5, draws=5)
    print("1 chain (a set to 22; expected for all: 3): ")
    print("- deterministic:", trace.posterior.deterministic.values)
    print("- normal:", trace.posterior.normal.values)
    print("- uniform:", trace.posterior.uniform.values)
    print("- (shared:", a.get_value(),")")
        
    a.set_value(44.)
    trace = pm.sample(step=fixed_step, chains=2, tune=5, draws=5)
    print("2 chains (a set to 44; expected for all: 3): ")
    print("- deterministic:", trace.posterior.deterministic.values)
    print("- normal:", trace.posterior.normal.values)
    print("- uniform:", trace.posterior.uniform.values)
    print("- (shared:", a.get_value(),")")

Here the custom step simply sets the shared variable to 3, whatever happens and whatever its initial value (22 or 44). So you would expect all three variables, “deterministic”, “normal”, and “uniform”, to be exactly or close to 3. Here are the results I get for one chain:

1 chain (a set to 22; expected for all: 3): 
- deterministic: [[3. 3. 3. 3. 3.]]
- normal: [[9.07357854 5.96475846 4.43589542 3.70905808 3.35497896]]
- uniform: [[3.00501058 3.00497676 3.00492586 3.00494255 3.00493957]]
- (shared: 3.0 )

This is as expected for all variables except the normal one, which varies far too much given its standard deviation of 0.01. With 2 chains, however:

2 chains (a set to 44; expected for all: 3): 
- deterministic: [[44. 44. 44. 44. 44.]
 [44. 44. 44. 44. 44.]]
- normal: [[ 5.2092079   4.08261882  3.54058747  3.27025367  3.13717144]
 [20.18981019 11.32478949  7.02944874  4.95407157  3.93772557]]
- uniform: [[44.00493403 44.00491381 44.00488669 44.00486748 44.00482371]
 [44.00500121 44.00498698 44.00499897 44.00498372 44.00500337]]
- (shared: 44.0 )

The shared value is not updated (it stays at 44), and the normal distribution still misbehaves. I checked that the sampling step itself is called, but the shared-variable update does not seem to take effect.
I verified that setting cores=1 fixes the update issue (but not the normal-distribution issue).

I am using what I believe are the most recent package versions:
pymc : 4.2.0
aesara: 2.8.2

Yeah… shared variables are very tricky with multiprocessing. They point to the same position in memory, so the processes would conflict if they tried to change them.

Sounds like a non-shared input (e.g., x = at.scalar()) would be the easiest solution, but I doubt that would work out of the box with the rest of PyMC's logic.

Good enough that it works with cores=1. Just out of curiosity, could you spell out your idea with at.scalar() a little? I am not very familiar with theano/aesara symbolic variable definitions. So far I have only used them to define a function for testing (e.g. to test the scan machinery). But if it does not work, no need to dwell on it. I’m already very happy with what we’ve reached so far.

Shared variables are just implicit (global) inputs, so the solution would be to use an implicit (local) input, but that’s not a thing. The next best option would be an explicit local input which is what stuff like at.scalar() are. The problem is that PyMC models don’t know about these and would not forward them to the logp/dlogp functions…
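A rough plain-Python analogy of this distinction (hypothetical names, not Aesara's API): a shared variable behaves like a global that the compiled function reads without it appearing in the signature, while an explicit input such as `at.scalar()` must be passed in by whoever calls the function. In PyMC's case, that caller would be the sampler, which does not know about the extra argument.

```python
# Plain-Python analogy, not Aesara code.
# SHARED_A stands in for aesara.shared(0.0): global state read implicitly.
SHARED_A = {"value": 0.0}


def logp_implicit(x):
    # Reads the global state: the dependency is invisible in the signature,
    # and every caller sees whatever value was last written to it.
    return -0.5 * (x - SHARED_A["value"]) ** 2


def logp_explicit(x, a):
    # `a` is an explicit argument, like an at.scalar() input:
    # each caller must supply its own value.
    return -0.5 * (x - a) ** 2
```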

The closest thing is indeed to use an isolated Flat, which won’t affect the logp and which you can update safely in multiprocessing:

import aesara
import aesara.tensor as at
import pymc as pm
import numpy as np

from pymc.step_methods.arraystep import BlockedStep

class UpdateFakeVariablesStep(BlockedStep):

    def __init__(self, fake_vars, trace, seed=None, model=None):
        model = pm.modelcontext(model)
        self.vars = [model.rvs_to_values.get(v, v) for v in fake_vars]
        self.rng = np.random.default_rng(seed=seed)
        self.trace = trace

    def step(self, point: dict):
        i = self.rng.integers(100)
        #i = self.rng.integers(self.trace.posterior.draw.size)

        for v in self.vars:
            #point[v.name] = self.trace.posterior[v.name].sel(draw=i, chain=0).values
            point[v.name] = np.array(i)

        return point


temp = np.array([0, 1])
obs = np.array([0.3, 3.5])

with pm.Model() as local_model:
    a = pm.Flat('a')
    aT0 = pm.Flat('aT0')
    update_fake_step = UpdateFakeVariablesStep([a, aT0], trace=None)

    global_slr = pm.Deterministic('global_slr', a*temp[1]-aT0)
    land_motion = pm.Normal('land_motion', 0, 2)
    local_slr = pm.Deterministic('local_slr', global_slr + land_motion)

    pm.Normal('local_slr_obs', local_slr, sigma=0.2, observed=obs)

    local_trace = pm.sample(step=[update_fake_step], chains=2)

The Normal is also fine, because it just adds a term that is constant within each NUTS step (though it varies between steps), which won’t bias it. But using a Flat has the advantage that PyMC will not treat it as something it is not in other places (prior/posterior predictive sampling, model comparison, etc.).
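Why a term that is frozen during a NUTS step cannot bias it: NUTS chooses among candidate points using relative weights built from differences of the log-density, and moves along trajectories driven by its gradient. An additive term that does not depend on the parameters being moved cancels in the former and vanishes from the latter. A small NumPy check, using a toy Metropolis-style log acceptance ratio and a finite-difference gradient (all names hypothetical):

```python
import numpy as np


def log_target(theta):
    # Toy log-density for the parameters that NUTS actually moves.
    return -0.5 * np.sum(theta ** 2)


def log_accept(theta_new, theta_old, extra=0.0):
    # Log acceptance ratio; `extra` is a term independent of theta
    # (e.g. the dummy variable's logp, fixed during this step).
    return (log_target(theta_new) + extra) - (log_target(theta_old) + extra)


def grad(f, theta, eps=1e-6):
    # Central finite-difference gradient of f at theta.
    g = np.zeros_like(theta)
    for k in range(theta.size):
        step = np.zeros_like(theta)
        step[k] = eps
        g[k] = (f(theta + step) - f(theta - step)) / (2 * eps)
    return g
```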


That’s great.

I was about to ask exactly that, why the wrong logp was an issue, and I’m glad you confirm that it is not (only its gradient is). I agree that Flat is cleaner, as it seems to serve more explicitly as a dummy distribution (I was unaware of that distribution class).

Thanks much again @ricardoV94

Edit: Maybe you actually want to consider the fixed distribution logp in the rest of the model…

I started some discussion here: Add FixedDistribution helper · Discussion #6275 · pymc-devs/pymc · GitHub

Ok thanks, I’ll keep an eye on it.

PS: in the code I ended up going back to a pm.Categorical() distribution (to index samples from a pre-existing trace), with a custom step to handle it. This has the advantage that sample_prior_predictive works (whereas it cannot handle a Flat distribution). And since every draw is equiprobable in my setup, it should not influence the logp.
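Why an equiprobable Categorical index does not influence the logp: with p_k = 1/n for each of the n stored draws, every index contributes the same log(1/n) = -log n to the model log-density, whichever draw the custom step picks. A NumPy sketch of that, together with the uniform pick such a custom step would make (`stored_a` and `index_step` are hypothetical stand-ins, not PyMC objects):

```python
import numpy as np

n_draws = 100
p = np.full(n_draws, 1.0 / n_draws)  # equal weight on every stored draw

# Log-probability of each index under Categorical(p): identical for all indices.
logp_idx = np.log(p)

# Hypothetical pre-existing draws of `a` (stand-in for a stored trace).
stored_a = np.random.default_rng(0).normal(1.0, 0.1, size=n_draws)


def index_step(rng):
    # What a custom step for the index does: pick a stored draw uniformly.
    return rng.integers(n_draws)


rng = np.random.default_rng(42)
idx = index_step(rng)
a_value = stored_a[idx]
```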

PPS: the main drawback of using an external step function, compared with the previous RandomStream approach, is that no more than one core can be used, for some reason (otherwise the sampling freezes).

We should be able to navigate across that limitation. Might require digging a bit in how the multiprocessing is handled.

Can you share a minimal reproducible example that leads to the “freezing”?

Good point. It will take a little effort to pin down, but will do ASAP.