Faster Nested Monte Carlo Estimator for EIG

Hi there,

I’m implementing a simple Nested Monte Carlo (NMC) estimator for the Expected Information Gain (EIG) for a logistic regression model. That is, I want to find the expected decrease in uncertainty for a candidate next sampling point (or design) d :

\operatorname{EIG}(d) \triangleq \mathbb{E}_{p(y \mid d)}[H[p(\theta)] - H[p(\theta \mid y, d)]]

Equivalently, the EIG is the mutual information between \theta and y,

\operatorname{EIG}(d) = \mathbb{E}_{p(\theta)\,p(y \mid \theta, d)}[\log p(y \mid \theta, d) - \log p(y \mid d)]

This is intractable because of the marginal p(y \mid d), but it can be estimated with this NMC estimator (see [1903.05480] Variational Bayesian Optimal Experimental Design):

\frac{1}{N}\sum_{n=1}^N \log p(y_n \mid \theta_n, d) - \frac{1}{N}\sum_{n=1}^N \log \left(\frac{1}{M}\sum_{m=1}^M p(y_n \mid \theta_m, d)\right)
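In plain NumPy terms, the estimator looks roughly like this (a minimal sketch; cond_lp and inner_lp are hypothetical arrays of precomputed log-likelihoods, not part of the model code below):

import numpy as np
from scipy.special import logsumexp

def nmc_eig(cond_lp, inner_lp):
    # cond_lp[n]     = log p(y_n | theta_n, d)
    # inner_lp[m, n] = log p(y_n | theta_m, d)
    M = inner_lp.shape[0]
    # log of the inner Monte Carlo estimate of the marginal p(y_n | d)
    marginal_lp = logsumexp(inner_lp, axis=0) - np.log(M)
    return np.mean(cond_lp - marginal_lp)
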

There already is a Pyro implementation, but Pyro/PyTorch is not great for actual posterior sampling (their NUTS implementation is quite slow), so I’d like to have something similar in PyMC.

My code so far works, but it's excruciatingly slow (mostly because of the repeated draw calls, I suppose). What would be the idiomatic PyMC way to write this more efficiently?

import numpy as np
import scipy.special as sp

import pymc
from pymc import Bernoulli, Model, Normal, draw

# true parameters (not used in the model)
true_k = 30
true_x0 = 0.5

# plotting domain
xmin = -0.3
xmax = 1.0

# priors
k_mu, k_sigma = 25.0, 1.0
x0_mu, x0_sigma = 1.0, 1.5

# parameters for the Nested Monte Carlo
N = 200
M = 50

candidates = np.linspace(xmin, xmax, 10)

with Model() as model:
    # priors
    k = Normal('k', mu=k_mu, sigma=k_sigma)
    x0 = Normal('x0', mu=x0_mu, sigma=x0_sigma)

    # the designs
    d = pymc.ConstantData('d', candidates)

    # the observed data
    observed_y = pymc.MutableData('observed_y', np.zeros(10))

    # the likelihood
    y = Bernoulli('y', logit_p=k * (d - x0), observed=observed_y)

    # log probabilities
    lp = pymc.Deterministic('lp', pymc.logp(y, y))
    lp_cond = pymc.Deterministic(
        'lp_cond',
        pymc.logprob.conditional_logp({
            y: observed_y,
        })[observed_y])

# outer Monte Carlo: draw y and its log-probability together so that both
# terms of the estimator use the same samples of k, x0 and y
conditional_lp, observed_data = draw([lp, y], draws=N)

# inner Monte Carlo: evaluate each drawn y_n under M fresh prior samples
marginal_lp = np.zeros((M, N, len(candidates)))

for n in range(N):
    observed_y.set_value(observed_data[n])
    marginal_lp[:, n, :] = draw(lp_cond, M)

marginal_lp = sp.logsumexp(marginal_lp, axis=0) - np.log(M)

# compute the expected information gain for every design
eig = (conditional_lp - marginal_lp).sum(axis=0) / N

print(eig)

You could try to use vectorize_graph to create a batched version of lp_cond that handles all the values of observed_y at once.
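Something along these lines, as a rough sketch (new_observed_y is just a placeholder name):

import pytensor.tensor as pt
from pytensor.graph.replace import vectorize_graph

# a batched stand-in for many observed_y vectors at once
new_observed_y = pt.matrix('new_observed_y')

# rebuild the lp_cond graph with the batched input substituted for observed_y
new_lp_cond = vectorize_graph(lp_cond, {observed_y: new_observed_y})
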

I'd also recommend trying with the latest version of PyMC.

That looks promising, thanks! But how would that work with calling draw()?

Like I can create a new Variable with

new_observed_y = pytensor.tensor.matrix('new_observed_y')
new_lp_cond = vectorize_graph(lp_cond, {observed_y: new_observed_y})

but how do I actually sample from that?

Edit: Okay, I can do

import pytensor.tensor
from pytensor.graph.replace import vectorize_graph
from pymc.pytensorf import compile_pymc

new_observed_y = pytensor.tensor.matrix('new_observed_y')
new_lp_cond = vectorize_graph(lp_cond, {observed_y: new_observed_y})

cdraw = compile_pymc(inputs=[new_observed_y], outputs=new_lp_cond)

cdraw(np.zeros((2, 10)))

But it looks like this method doesn't sample each prior value individually: it grabs just one set of x0 and k values and uses them for all elements. Is there a way to make sure each element in the input gets its own “fresh” set of samples (like with the draws parameter in pymc.draw())?

I'll have to think a bit about what the issue is; maybe you need to batch the root RVs as well. But first, why can't you use pm.draw on the vectorized output?

Anyway, if this doesn't work, you can just compile the function generated by draw once and reuse it. That should be faster than calling pm.draw in a loop, which recompiles the function every time.
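A rough sketch of what I mean, reusing lp_cond, observed_y, and observed_data from your code above:

from pymc.pytensorf import compile_pymc

# compile once; every call draws fresh k, x0 values
lp_cond_fn = compile_pymc(inputs=[], outputs=lp_cond)

marginal_lp = np.zeros((M, N, len(candidates)))
for n in range(N):
    observed_y.set_value(observed_data[n])
    # M evaluations of log p(y_n | theta_m, d) under fresh prior draws
    marginal_lp[:, n, :] = np.stack([lp_cond_fn() for _ in range(M)])
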

Compiling the function does help, but you still end up with lots of repeated invocations, which is probably not ideal (I haven't benchmarked it yet; it's unbearably slow).

I tried batching the other variables as well, and the code below seems to work now. However, it's still roughly 70% slower than the Pyro implementation. I'll need to check what takes so long, but it's also causing problems with actually sampling from the posterior.

D = len(candidates)  # number of candidate designs

with Model() as model:

    # the observed data (placeholder value; real values are fed in later)
    observed_y = pymc.MutableData('observed_y', np.zeros((D, 2)))

    # the designs (placeholder value; set via d.set_value() below)
    d = pymc.MutableData('d', np.ones((D, 2)))

    # priors, batched to the shape of d so every element gets its own draw
    k = Normal('k', mu=k_mu, sigma=k_sigma, shape=d.shape)
    x0 = Normal('x0', mu=x0_mu, sigma=x0_sigma, shape=d.shape)

    # parameters
    logit = pymc.Deterministic('logit', k * (d - x0))

    # the likelihood
    y = Bernoulli('y',
                  logit_p=logit,
                  observed=observed_y,
                  shape=d.shape)

    # log probabilities
    lp = pymc.Deterministic('lp', pymc.logp(y, y))
    lp_cond = pymc.Deterministic(
        'lp_cond',
        pymc.logprob.conditional_logp({
            y: observed_y,
        })[observed_y])

# compile the draw functions
cdraw1 = compile_pymc(inputs=[], outputs=[lp, y])

new_observed_y, new_d = pytensor.tensor.matrix('new_observed_y'), pytensor.tensor.matrix('new_d')
new_lp_cond = vectorize_graph(lp_cond, {observed_y: new_observed_y, d: new_d})
cdraw2 = compile_pymc(inputs=[new_observed_y, new_d], outputs=new_lp_cond)

# set the designs (candidate sampling points)
d.set_value(np.repeat(np.linspace(xmin, xmax, D)[:, np.newaxis], N, axis=1))

# sampling
conditional_lp, observed_data = cdraw1()
observed_data = np.repeat(observed_data[:, :, np.newaxis], M, axis=2)
new_d_value = np.repeat(np.linspace(xmin, xmax, D)[:, np.newaxis], M, axis=1)

# inner Monte Carlo: evaluate each drawn y_n under M fresh prior draws
marginal_lp = np.zeros((M, N, D))

for n in range(N):
    marginal_lp[:, n, :] = cdraw2(observed_data[:, n], new_d_value).T

marginal_lp = sp.logsumexp(marginal_lp, axis=0) - np.log(M)

# compute the expected information gain for every design
eig = (conditional_lp.T - marginal_lp).sum(axis=0) / N