I am not sure that approach was sound in V3, because the categorical draw would be updated in every inner step of NUTS, and I imagine you want to hold that constant and only change it between complete NUTS steps.
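For illustration, a minimal sketch (assuming the aesara RandomStream API) of why a stream-based draw changes more often than intended: the draw is refreshed on every call of the compiled function, i.e. on every logp/gradient evaluation inside a NUTS trajectory, not once per complete step.

import aesara
from aesara.tensor.random.utils import RandomStream

srng = RandomStream(seed=123)
x = srng.normal(0, 1)        # refreshed via the stream's default update
f = aesara.function([], x)
print([float(f()) for _ in range(3)])  # a new draw on every call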
Yes, probably (EDIT: perhaps the modelling was straightforward enough that it did not matter, but that approach nevertheless delivered reasonable results when I examined them.)
You gave me plenty of material to improve the work. Looking forward to putting it into practice.
As an update: the method proposed by @ricardoV94 seems to work (though in my use case I see no difference from the other methods: same result and convergence time, and unfortunately I still get this 2x or 3x slower execution in pymc v4 compared to pymc3). While adapting the approach to pymc3 (to compare execution time and results, and to address the issues raised by @ricardoV94, namely that a) RandomStream will update the variable too often and b) a custom step with dummy Normal variables will invalidate the logp), I had to introduce a dummy variable for the SharedVariableUpdate step to be executed. Something like:
import numpy as np
import theano
import pymc3 as pm
from pymc3.step_methods.arraystep import BlockedStep

class UpdateSharedVariablesStepV3(BlockedStep):
    def __init__(self, shared_vars, trace, seed=None):
        self.vars = shared_vars
        self.rng = np.random.default_rng(seed=seed)
        self.trace = trace

    def step(self, point: dict):
        i = self.rng.integers(100)
        # i = self.rng.integers(self.trace.posterior.draw.size)
        for v in self.vars:
            if v.name.startswith("dummy"):
                point[v.name] = 0.
                continue
            v.set_value(i)
            # v.set_value(self.trace.posterior[v.name].sel(draw=i, chain=0).values)
        return point

with pm.Model() as local_model:
    a = theano.shared(0.0, name='a')
    aT0 = theano.shared(0.0, name='aT0')
    dummy_step_variable = pm.Normal('dummy_step_variable', 0, 1)  # does it also impact logp in any undesirable way?
    update_shared_step = UpdateSharedVariablesStepV3([a, aT0, dummy_step_variable], trace=None)
    ...
    local_trace = pm.sample(step=[update_shared_step], chains=1)
Hi @ricardoV94 ,
sorry to come back to this, but I am worried that the approach you proposed (updating a shared variable) does not work for multiple chains when multiple cores are used, not because of the seed, but likely because of some bug around multiprocessing. Additionally, there seems to be a bug when using the shared value as the mean of a Normal distribution (that one is more mysterious to me). I made a minimal working example to demonstrate the issue:
import pymc as pm
import aesara
from pymc.step_methods.arraystep import BlockedStep

class UpdateSharedVarStep(BlockedStep):
    def __init__(self, shared_vars):
        self.vars = shared_vars

    def step(self, point: dict) -> dict:
        for v in self.vars:
            v.set_value(3)
        return point

with pm.Model():
    a = aesara.shared(0., name='a')
    pm.Deterministic('deterministic', a)  # save in trace
    pm.Normal('normal', a, 1e-2)
    pm.Uniform('uniform', a, a + 1e-2)
    fixed_step = UpdateSharedVarStep([a])

    a.set_value(22.)
    trace = pm.sample(step=fixed_step, chains=1, tune=5, draws=5)
    print("1 chain (a set to 22; expected for all: 3): ")
    print("- deterministic:", trace.posterior.deterministic.values)
    print("- normal:", trace.posterior.normal.values)
    print("- uniform:", trace.posterior.uniform.values)
    print("- (shared:", a.get_value(), ")")

    a.set_value(44.)
    trace = pm.sample(step=fixed_step, chains=2, tune=5, draws=5)
    print("2 chains (a set to 44; expected for all: 3): ")
    print("- deterministic:", trace.posterior.deterministic.values)
    print("- normal:", trace.posterior.normal.values)
    print("- uniform:", trace.posterior.uniform.values)
    print("- (shared:", a.get_value(), ")")
Here the custom step simply sets the shared variable value to 3, whatever happens and whatever the initial value (22 or 44). So you'd expect all three variables 'deterministic', 'normal', and 'uniform' to be exactly 3 or close to it. Now here are the results I get for one chain:
1 chain (a set to 22; expected for all: 3):
- deterministic: [[3. 3. 3. 3. 3.]]
- normal: [[9.07357854 5.96475846 4.43589542 3.70905808 3.35497896]]
- uniform: [[3.00501058 3.00497676 3.00492586 3.00494255 3.00493957]]
- (shared: 3.0 )
As expected for all but the normal distribution (it varies far too much considering the standard deviation of 0.01). For 2 chains however:
2 chains (a set to 44; expected for all: 3):
- deterministic: [[44. 44. 44. 44. 44.]
[44. 44. 44. 44. 44.]]
- normal: [[ 5.2092079 4.08261882 3.54058747 3.27025367 3.13717144]
[20.18981019 11.32478949 7.02944874 4.95407157 3.93772557]]
- uniform: [[44.00493403 44.00491381 44.00488669 44.00486748 44.00482371]
[44.00500121 44.00498698 44.00499897 44.00498372 44.00500337]]
- (shared: 44.0 )
The shared value is not updated (it stays at 44), and the normal distribution keeps going nuts. I checked that the custom sampling step is called, but the shared variable update does not seem to take effect. I verified that setting cores=1 fixes the update issue (but not the normal distribution issue).
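For reference, the workaround is just the cores argument of pm.sample (same toy numbers as in the example above):

trace = pm.sample(step=fixed_step, chains=2, cores=1, tune=5, draws=5)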
The package versions I use are the most recent, I believe:
pymc  : 4.2.0
aesara: 2.8.2
Yeah… shared variables are very tricky with multiprocessing. They point to the same position in memory, so the processes would conflict if they tried to change it.
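A minimal, non-PyMC sketch (plain multiprocessing, hypothetical names) of why in-place updates made in worker processes are fragile: each worker effectively operates on its own copy of the parent's state, so a set_value-style mutation never makes it back to the parent or to sibling chains, consistent with the stale 44 observed above.

import multiprocessing as mp

state = {"a": 44.0}  # stands in for the shared variable's storage

def worker(_):
    state["a"] = 3.0          # mutates the copy inside this worker process
    return state["a"]

if __name__ == "__main__":
    with mp.Pool(2) as pool:
        print(pool.map(worker, range(2)))  # [3.0, 3.0] as seen by the workers
    print(state["a"])                      # still 44.0 in the parent process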
Sounds like a non-shared input (e.g., x = at.scalar()) would be the easiest solution, but I doubt that would work out of the box with the rest of the PyMC logic.
Good enough that it works with cores=1. Just out of curiosity, would you spell out your thought a little with at.scalar()? I am not very familiar with theano/aesara symbolic variable definitions. So far I have only used them to define a function (e.g., to test the scan machinery). But yeah, if it does not work, no need to delve into that. I'm already very happy with what we've reached so far.
Shared variables are just implicit (global) inputs, so the solution would be to use an implicit (local) input, but that's not a thing. The next best option would be an explicit local input, which is what stuff like at.scalar() are. The problem is that PyMC models don't know about these and would not forward them to the logp/dlogp functions…
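To illustrate the distinction, a minimal aesara sketch (not PyMC-specific): an explicit input must be declared and passed at call time, whereas a shared variable is captured implicitly by the graph.

import aesara
import aesara.tensor as at

x = at.scalar("x")                   # explicit input: declared and passed at call time
f = aesara.function([x], 2 * x + 1)
print(f(3.0))                        # 7.0

s = aesara.shared(3.0, name="s")     # implicit (global) input: captured from the graph
g = aesara.function([], 2 * s + 1)
print(g())                           # 7.0, with s supplied behind the scenes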
The closest thing is indeed to use an isolated Flat, which won't affect the logp and which you can update safely in multiprocessing:
import aesara
import aesara.tensor as at
import pymc as pm
import numpy as np
from pymc.step_methods.arraystep import BlockedStep

class UpdateFakeVariablesStep(BlockedStep):
    def __init__(self, fake_vars, trace, seed=None, model=None):
        model = pm.modelcontext(model)
        self.vars = [model.rvs_to_values.get(v, v) for v in fake_vars]
        self.rng = np.random.default_rng(seed=seed)
        self.trace = trace

    def step(self, point: dict):
        i = self.rng.integers(100)
        # i = self.rng.integers(self.trace.posterior.draw.size)
        for v in self.vars:
            # point[v.name] = self.trace.posterior[v.name].sel(draw=i, chain=0).values
            point[v.name] = np.array(i)
        return point

temp = np.array([0, 1])
obs = np.array([0.3, 3.5])

with pm.Model() as local_model:
    a = pm.Flat('a')
    aT0 = pm.Flat('aT0')
    update_fake_step = UpdateFakeVariablesStep([a, aT0], trace=None)
    global_slr = pm.Deterministic('global_slr', a * temp[1] - aT0)
    land_motion = pm.Normal('land_motion', 0, 2)
    local_slr = pm.Deterministic('local_slr', global_slr + land_motion)
    pm.Normal('local_slr_obs', local_slr, sigma=0.2, observed=obs)
    local_trace = pm.sample(step=[update_fake_step], chains=2)
The Normal is also fine, because it just adds a (variable) constant term in each step of NUTS, which won't bias it. But using a Flat has the advantage that PyMC will not try to assume this is something that it's not in other places (prior/posterior predictive, model comparison, etc…)
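A quick sanity check of that claim (a sketch assuming the PyMC 4.x compile_logp API): a Flat contributes nothing to the model logp, so shifting its value and the dependent value together leaves the logp unchanged.

import pymc as pm

with pm.Model() as m:
    f = pm.Flat("f")
    x = pm.Normal("x", mu=f, sigma=1.0)

logp_fn = m.compile_logp()
print(logp_fn({"f": 0.0, "x": 0.0}))  # logp of Normal(0, 1) at 0
print(logp_fn({"f": 5.0, "x": 5.0}))  # identical: Flat adds 0 to the logp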
That's great.
I was about to ask exactly that, why the wrong logp was an issue, and I'm glad you confirm that it is not (only its gradient is). I agree that Flat is cleaner, as it serves more explicitly as a dummy distribution (I was unaware of that distribution class).
Thanks much again @ricardoV94
Edit: Maybe you actually want to consider the fixed distribution logp in the rest of the model…
I started some discussion here: Add FixedDistribution helper · Discussion #6275 · pymc-devs/pymc · GitHub
Ok thanks, I'll keep an eye on it.
PS: in the code I ended up turning back to a pm.Categorical() distribution (to index samples from a pre-existing trace) and a custom step to handle it. That has the advantage that sample_prior_predictive works (whereas it cannot handle a Flat distribution). And since every draw is equiprobable (in my setup), it should not influence the logp. A sketch of this variant follows below.
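For concreteness, a minimal sketch of that Categorical variant under stated assumptions: the pre-existing trace is stood in for by a plain array, and the names (UpdateIndexStep, prior_samples, n_draws) are hypothetical. It follows the same BlockedStep pattern as the examples above.

import numpy as np
import aesara.tensor as at
import pymc as pm
from pymc.step_methods.arraystep import BlockedStep

n_draws = 100
prior_samples = np.linspace(0.0, 1.0, n_draws)  # stand-in for a pre-existing trace

class UpdateIndexStep(BlockedStep):
    def __init__(self, index_var, seed=None, model=None):
        model = pm.modelcontext(model)
        self.vars = [model.rvs_to_values.get(index_var, index_var)]
        self.rng = np.random.default_rng(seed)

    def step(self, point: dict):
        # resample the index once per complete sampler step
        point[self.vars[0].name] = np.array(self.rng.integers(n_draws))
        return point

with pm.Model():
    # equiprobable index: its logp is a constant and does not tilt sampling
    idx = pm.Categorical("idx", p=np.full(n_draws, 1.0 / n_draws))
    a = pm.Deterministic("a", at.constant(prior_samples)[idx])
    pm.Normal("obs", mu=a, sigma=0.1, observed=0.5)
    idx_step = UpdateIndexStep(idx)
    trace = pm.sample(step=[idx_step], chains=1)  # cores > 1 reportedly freezes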
PPS: the main drawback of using an external step function, compared to the earlier RandomStream approach, is that no more than one core can be used, for some reason (otherwise the sampling freezes).
We should be able to work around that limitation. It might require digging a bit into how the multiprocessing is handled.
Can you share a minimal reproducible example that leads to "freezing"?
Good point. It will take a little effort to pin down, but will do ASAP.