Conditionally reject samples

Hi, I’m trying to implement conditioning in PyMC3. Say you make X = N(-10, 5), compute Y = X * X, and then conditionally reject all samples with Y > 10 in order to infer (X | Y < 10). It seems that the only way to manipulate the likelihood is pm.Potential, so I’m doing something like this:

with pm.Model() as model:
    x = pm.Normal('x', mu=-10, sd=5)
    y = pm.Deterministic('y', x * x)
    pm.Potential('cond', pm.math.switch(y < 10, 0, -np.inf))
    result = pm.sample(draws=1000, step=pm.Metropolis())

pm.traceplot(result, varnames=['x']);

What sampler should I use? Metropolis always gives -10, and NUTS fails with SamplingError: Bad initial energy.
Maybe there’s another, correct way to implement this?

If I just use a negative number of large magnitude, NUTS ignores the condition.

MH produces correct samples.

NUTS starts working only with relatively small logp penalties.

You can hack this with an observed Bernoulli distribution:

import pymc3 as pm
import theano
import theano.tensor as tt
import numpy as np

with pm.Model() as model:
    x = pm.Normal('x', mu=-10, sd=5.)
    y = pm.Deterministic('y', x**2)
    ind_sw = pm.Deterministic('ind_p', 1. * (y > 10.))
    ind = pm.Bernoulli('y_gt_10', ind_sw, observed=1)
    tr = pm.sample(500, tune=1000, cores=2, chains=10)

pm.traceplot(tr, 'x');

and with observed=0 you get the complementary condition, i.e. samples with y ≤ 10.


Thank you!
I’ve tried changing (y > 10.) to (y < 10.) and to (y > 3.) & (y < 10.)
In both cases I get the same Bad initial energy error.

I’m using Colab

Metropolis also failed.

Try x = pm.Normal('x', mu=-10, sd=5., testval=1.)

Thanks, but I’d like to find a general solution for conditioning in PyMC3 and wrap it in a helper function to make it easier, something like the condition function in WebPPL. Computing initial points for a Markov chain by hand could be quite hard for more complex models.

Conditioning on deterministics is in general quite hard, and I think with PyMC3 you would have to do some math on your end for that. Note that pm.Bound exists to bound a distribution like pm.Normal (and pm.TruncatedNormal, which is far more efficient, exists too!). I believe most PPLs that support a general conditioning statement do so via rejection sampling, which may be disastrously inefficient.

Here’s your original example:

with pm.Model() as model:
    x = pm.TruncatedNormal('x', mu=-10, sd=5, lower=-(10**0.5), upper=10**0.5)
    y = pm.Deterministic('y', x * x)
    result = pm.sample()

pm.traceplot(result)


If you think about it, you can’t avoid rejection sampling if you

  1. allow conditions to completely restrict the sample space
  2. restrict yourself to not specifying initial points

because there is no guarantee that some later-applied condition won’t conflict with the initial point (or, indeed, with each other!). If you want the speedup of a higher-order sampling method you have to live with loosening either (1) or (2); perhaps the simplest way would be a hybrid approach:

EPSILON = 0.01
with pm.Model() as model_approx:
    x = pm.Normal('x', mu=-10., sd=5.)
    y = pm.Deterministic('y', x**2)
    ind_sw = pm.Deterministic('ind_p', 1. * (y > 10.))
    ind = pm.Bernoulli('y_gt_10', (1 - EPSILON) * ind_sw, observed=0)
    tr = pm.sample(100, tune=100, cores=1, chains=1, nuts_kwargs={'max_treedepth': 20})

pm.traceplot(tr, 'x');
probable_near_mode = np.random.choice(tr['x'])

with pm.Model() as model:
    x = pm.Normal('x', mu=-10, sd=5., testval=probable_near_mode)
    y = pm.Deterministic('y', x**2)
    ind_sw = pm.Deterministic('ind_p', 1. * (y > 10.))
    ind = pm.Bernoulli('y_gt_10', ind_sw, observed=0)
    tr2 = pm.sample(500, tune=1000, cores=2, chains=10)


pm.traceplot(tr2, 'x');

Even this approach is unlikely to work if you condition on a set of measure 0, such as P(x|y=8), but then again neither will rejection sampling.


Thank you! I understand now that both algorithms (NUTS, Metropolis) fail when their starting point lies in a flat region of zero probability: MH rejects every step (all neighbors have zero probability), and NUTS fails because it needs gradients to find and explore the typical set.
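The MH failure mode can be reproduced in a few lines of plain NumPy (a toy sketch of a Metropolis step, not PyMC3 internals; the constants mirror the running example): with a hard -inf log-probability outside x**2 < 10 and a start at x = -10, every proposal also lands in the zero-probability region, so the chain never moves.

```python
import numpy as np

def logp(x):
    # unnormalized log density: N(-10, 5) prior with the hard constraint x**2 < 10
    if x * x >= 10.0:
        return -np.inf
    return -0.5 * ((x + 10.0) / 5.0) ** 2

rng = np.random.default_rng(0)
x = -10.0          # start in the flat zero-probability region, as pm.Metropolis did
chain = []
for _ in range(1000):
    prop = x + rng.normal(scale=1.0)
    lp_x, lp_p = logp(x), logp(prop)
    # a proposal can only be accepted if it has non-zero probability; from
    # x = -10 that requires a ~7-sigma jump, so in practice every step is rejected
    if lp_p > -np.inf and np.log(rng.uniform()) < lp_p - lp_x:
        x = prop
    chain.append(x)
```

The chain stays at -10 for all 1000 draws, matching the “Metropolis always gives -10” symptom above.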

You are 100% right that rejection sampling is not efficient, especially in higher dimensions. However, I believe it’s better to have it as an option for a sampler.

  1. Conditioning is one of the basic inference techniques. Bounding the input variables manually is only possible in artificial examples like ‘X ** 2’. What would you do if we conditioned on something like ‘X1 * X2 > X3’? If rejection sampling is the only way to do such inference, we need it.

  2. It should be possible to use rejection sampling to find regions of non-zero probability and then use Metropolis or other step methods to explore them; if multiple starting points are generated, there is a high chance of covering all the needed regions.
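The hybrid in point 2 can be sketched in plain NumPy on the running example (a toy stand-in, not PyMC3 API; the chain counts and step size are arbitrary assumptions): rejection-sample the prior just until a few draws satisfy the condition, then run a cheap Metropolis chain from each of them on the restricted density.

```python
import numpy as np

rng = np.random.default_rng(1)

def log_prior(x):
    return -0.5 * ((x + 10.0) / 5.0) ** 2    # N(-10, 5), unnormalized

def cond(x):
    return x * x < 10.0                      # the condition y < 10

# Step 1: rejection-sample the prior only long enough to find valid start points.
starts = []
while len(starts) < 4:
    draw = rng.normal(-10.0, 5.0)
    if cond(draw):
        starts.append(draw)

# Step 2: explore from each start with a Metropolis chain whose target
# is the prior restricted to the condition set.
samples = []
for x in starts:
    for _ in range(2000):
        prop = x + rng.normal(scale=0.5)
        if cond(prop) and np.log(rng.uniform()) < log_prior(prop) - log_prior(x):
            x = prop
        samples.append(x)
samples = np.array(samples)
```

Every retained sample satisfies the condition, and only the handful of initial draws pay the rejection-sampling cost.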

It seems that Metropolis does manage to produce samples for such problems in WebPPL. The autocorrelation is huge, but with lag 100 the samples are not totally random, compared to rejection sampling.

I’ll check how Stan and Pyro deal with such problems.


As an open source library, pull requests are welcome! Maybe it would be efficient to build a separate library on top of PyMC3 using something like tempering or simulated annealing: for t in np.linspace(0, 1), sample from t * pdf(x) + (1 - t) * pdf(x) * cond(x).

Turning that into a valid and efficient MCMC strategy is tricky!


Perhaps there’s a better place to continue this discussion, but how often is it the case that for a condition (event) set \mathcal{A} \subseteq \mathcal{X} you can define some differentiable metric d(x, \mathcal{A})? This would enable annealing on a Gibbs potential

\ell \propto \mathrm{pdf}(x)\prod_{j=1}^k \exp\left\{-\tau \times d(x, \mathcal{A}_j)\right\}

which would give gradients towards the condition set for higher-order methods.

@zemlyansky This is probably a better way to use pm.Potential too (I find it runs about two orders of magnitude faster than the flat potential: 382.20 it/s vs 18.93 it/s):

tau = 5.
with pm.Model() as model:
    x = pm.Normal('x', mu=-10, sd=5)
    y = pm.Deterministic('y', x * x)
    d = pm.Deterministic('d', (y - 10) ** 2)
    pm.Potential('cond', pm.math.switch(y < 10, 0, -tau * d))
    result = pm.sample(1000, tune=100, chains=3, cores=1, nuts_kwargs={'target_accept': 0.95})

pm.traceplot(result, varnames=['x']);

Only ~2% of samples violate the condition, and the maximum value is y=10.764.... Increasing \tau too much, as you noticed, tends to break the sampler; but if the number of violations is small, they can be removed by a quick rejection pass at the very end.
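That final rejection pass is just a boolean mask over the trace arrays. A minimal NumPy sketch, using synthetic draws as stand-ins for result['x'] and result['y'] (the ~2% violation rate here is simulated, not taken from a real run):

```python
import numpy as np

# stand-ins for the trace arrays result['x'] and result['y']:
# draws hugging the boundary, with a small fraction of violations
rng = np.random.default_rng(2)
x_samples = rng.normal(-2.5, 0.3, size=3000)
y_samples = x_samples ** 2

keep = y_samples < 10.0        # the rejection pass: drop violating draws
x_clean = x_samples[keep]
```

Because the soft potential keeps violations rare, the pass discards only a few percent of the draws instead of the vast majority that plain rejection sampling would throw away.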

@colcarroll What’s missing for this approach is a callback in the sampler, to do things like ramping tau to its final value during tuning.


I am afraid I have lost the thread/lack the necessary background. Simulated annealing is a topic on which I have been looking for a good reference for a few years now…

That is an interesting idea. Could we hack it by fitting the model in a loop, varying t, and using trace[0] as the start position for the next sampler?
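The loop idea can be sketched in plain NumPy (a toy Metropolis stand-in, not PyMC3 API; the tau schedule, penalty form, and step counts are all assumptions): run a short chain at each temperature, ramping the soft penalty from 0 to its final strength, and seed each stage with the last state of the previous one.

```python
import numpy as np

rng = np.random.default_rng(3)

def log_prior(x):
    return -0.5 * ((x + 10.0) / 5.0) ** 2        # N(-10, 5), unnormalized

def logp(x, tau):
    # soft condition: free inside y < 10, penalized by tau * (y - 10)**2 outside
    y = x * x
    return log_prior(x) + (0.0 if y < 10.0 else -tau * (y - 10.0) ** 2)

x = -10.0                       # the tau = 0 stage is unconstrained, so any start works
stage_ends = []
for tau in np.linspace(0.0, 5.0, 6):             # ramp the penalty over stages
    for _ in range(2000):
        prop = x + rng.normal(scale=1.0)
        if np.log(rng.uniform()) < logp(prop, tau) - logp(x, tau):
            x = prop
    stage_ends.append(x)        # warm start: the final state seeds the next stage
```

Because tau grows gradually, the chain always has gradients (here: finite log-probability differences) pulling it toward the condition set, and by the last stage it sits inside y < 10 up to the softness of the penalty.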

I have been thinking about a more structured warmup phase for PyMC4. Stan has “fast” (I, III) and “slow” (II) intervals, which roughly correspond to step-size adaptation and mass-matrix adaptation:
[figure: Stan’s warmup intervals — fast (I), slow (II), fast (III) — from https://mc-stan.org/docs/2_19/reference-manual/hmc-algorithm-parameters.html]

On the one hand, it would be neat to formalize these phases, and provide a public API for researchers/users to implement custom phases that tune training parameters. On the other hand, I want to see how complex this makes the warmup system.
