@opherdonchin PyMC can compute that, but you are right that most step samplers don’t exploit this directly. When you assign a variable to a step sampler like NUTS or Slice, it computes a model logp function that has all the terms, even those that won’t change when that variable changes. (NUTS will also compute the dlogp, and that one is only with respect to the variables that are being updated.)
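To make that concrete, here is a small sketch of what gets compiled in that case (assuming the Model.logp / Model.dlogp methods of recent PyMC versions; exact argument names may differ):

import pymc as pm

with pm.Model() as sketch_model:
    x = pm.Normal("x")
    y = pm.Normal("y")

# The logp graph handed to a step sampler contains every term,
# even those that do not depend on the variable being updated
logp_graph = sketch_model.logp()

# NUTS additionally builds the gradient, but only with respect to
# the variables it is updating (restricted to x here for illustration)
dlogp_x = sketch_model.dlogp(grad_vars=[x])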
In contrast, some step samplers, like Metropolis, ask for a delta_logp function, which is logp(x1) - logp(x2). In this case the PyTensor backend will automatically remove terms that are identical between the two sides, keeping only the terms that depend on the changing variable.
The only place I am aware we explicitly exploit conditional dependencies is in MarginalModel, where we compute the logp of a marginalized discrete variable by considering only the variables in its Markov blanket, for efficiency purposes: pymc-experimental/pymc_experimental/model/marginal/marginal_model.py at 4deeec6490c755ff77ae6f79d20d88f233e508e1 · pymc-devs/pymc-experimental · GitHub
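As a rough sketch of what that looks like from the user side (assuming the MarginalModel API exposed by pymc-experimental around that commit; the entry points may have moved since):

import numpy as np
import pymc as pm
from pymc_experimental import MarginalModel

with MarginalModel() as mm:
    # Discrete variable to be marginalized out
    idx = pm.Bernoulli("idx", p=0.7)
    # y is the only variable in idx's Markov blanket
    y = pm.Normal(
        "y",
        mu=pm.math.switch(idx, -1.0, 1.0),
        sigma=1.0,
        observed=np.array([0.5, -0.3]),
    )

# The marginalized logp of idx only needs to look at its Markov blanket (y)
mm.marginalize([idx])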
We could probably exploit this in more places, but historically we didn’t because 1) most of the time we just use a single sampler to update all variables, and 2) we didn’t have a good internal representation in previous versions of PyMC (< 4.0).
Here is a minimal example of the delta_logp:
import pymc as pm
from pytensor.compile.mode import get_mode
with pm.Model() as m:
    x = pm.Normal("x")
    y = pm.Normal("y")
logp1 = m.logp(vars=[x, y])
logp2 = m.logp(vars=[x])
mode = get_mode("FAST_RUN").excluding("fusion") # for readability
m.compile_fn(logp1, mode=mode).f.dprint()
# Add [id A] '__logp' 4
# ├─ -1.8378770664093453 [id B]
# ├─ Mul [id C] 3
# │ ├─ -0.5 [id D]
# │ └─ Sqr [id E] 2
# │ └─ x [id F]
# └─ Mul [id G] 1
# ├─ -0.5 [id D]
# └─ Sqr [id H] 0
# └─ y [id I]
m.compile_fn(logp1 - logp2, mode=mode).f.dprint()
# Sub [id A] 'sigma > 0' 2
# ├─ Mul [id B] 1
# │ ├─ -0.5 [id C]
# │ └─ Sqr [id D] 0
# │ └─ y [id E]
# └─ 0.9189385332046727 [id F]
Notice that x is not part of the second function.
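For completeness, the compiled difference can also be evaluated; since the x terms were removed, only the value of y affects the result (a sketch, assuming compile_fn’s point-function call convention):

delta_fn = m.compile_fn(logp1 - logp2, mode=mode)
delta_fn({"x": 123.0, "y": 0.0})  # ≈ -0.9189, independent of x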