Faster Nested Monte Carlo Estimator for EIG

Compiling the function does help, but you still end up with lots of repeated invocations, which is probably not ideal (I haven't benchmarked that yet; it's unbearably slow).
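The pattern I mean looks roughly like this (hypothetical sketch; clogp stands in for a compiled single-dataset log-likelihood function that draws fresh prior samples internally on every call):

# hypothetical: one compiled call per (outer, inner) sample pair,
# so Python loop overhead dominates the runtime
marginal_lp = np.empty((M, N, D))
for n in range(N):
    for m in range(M):
        marginal_lp[m, n, :] = clogp(observed_data[:, n], designs)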

I tried batching the other variables as well, and the code below seems to work now. However, it's still roughly 70% slower than the Pyro implementation. I'll need to check what takes so long, but it's also causing problems with actually sampling from the posterior.
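For reference, the nested Monte Carlo estimator the code computes is, with \theta = (k, x_0) drawn from the prior, y_n a prior-predictive draw, and \tilde{\theta}_{n,m} fresh inner prior draws:

$$
\widehat{\mathrm{EIG}}(d) = \frac{1}{N} \sum_{n=1}^{N} \left[ \log p(y_n \mid \theta_n, d) \;-\; \log \frac{1}{M} \sum_{m=1}^{M} p(y_n \mid \tilde{\theta}_{n,m}, d) \right]
$$

In the code, conditional_lp holds the first term and marginal_lp (after the logsumexp) the second.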

import numpy as np
import pymc
import pytensor
import scipy.special as sp
from pymc import Model, Normal, Bernoulli
from pymc.pytensorf import compile_pymc
from pytensor.graph.replace import vectorize_graph

# placeholder sizes and hyperparameters (assumed, adjust to taste)
D, N, M = 11, 100, 100          # designs, outer samples, inner samples
xmin, xmax = -2.0, 2.0          # design grid bounds
k_mu, k_sigma = 1.0, 1.0        # prior on the slope
x0_mu, x0_sigma = 0.0, 1.0      # prior on the threshold

with Model() as model:

    # the observed data
    observed_y = pymc.MutableData('observed_y', np.zeros((D, 2)))

    # the designs
    d = pymc.MutableData('d', np.ones((D, 2)))

    # priors
    k = Normal('k', mu=k_mu, sigma=k_sigma, shape=d.shape)
    x0 = Normal('x0', mu=x0_mu, sigma=x0_sigma, shape=d.shape)

    # parameters
    logit = pymc.Deterministic('logit', k * (d - x0))

    # the likelihood
    y = Bernoulli('y',
                  logit_p=logit,
                  observed=observed_y,
                  shape=d.shape)

    # log probabilities: lp scores each prior-predictive draw of y under
    # its own parameters; lp_cond scores the (replaceable) observed data
    lp = pymc.Deterministic('lp', pymc.logp(y, y))
    lp_cond = pymc.Deterministic(
        'lp_cond',
        pymc.logprob.conditional_logp({
            y: observed_y,
        })[observed_y])

# compile a sampler for the outer draws: fresh (k, x0, y) from the
# prior predictive, plus the matching log-likelihoods
cdraw1 = compile_pymc(inputs=[], outputs=[lp, y])

# vectorize lp_cond so the observed data and designs become free inputs,
# then compile it for the inner (marginal) log-likelihood evaluations
new_observed_y, new_d = pytensor.tensor.matrix('new_observed_y'), pytensor.tensor.matrix('new_d')
new_lp_cond = vectorize_graph(lp_cond, {observed_y: new_observed_y, d: new_d})
cdraw2 = compile_pymc(inputs=[new_observed_y, new_d], outputs=new_lp_cond)

# set the designs (candidate sampling points): one column per outer sample
d.set_value(np.repeat(np.linspace(xmin, xmax, D)[:, np.newaxis], N, axis=1))

# outer Monte Carlo step: draw N prior-predictive datasets per design,
# then tile them M times for the inner evaluations
conditional_lp, observed_data = cdraw1()
observed_data = np.repeat(observed_data[:, :, np.newaxis], M, axis=2)
new_d_value = np.repeat(np.linspace(xmin, xmax, D)[:, np.newaxis], M, axis=1)

# inner Monte Carlo step: score each observed dataset under M fresh prior draws
marginal_lp = np.empty((M, N, D))

for n in range(N):
    marginal_lp[:, n, :] = cdraw2(observed_data[:, n], new_d_value).T

# log of the inner average: an estimate of log p(y_n | d)
marginal_lp = sp.logsumexp(marginal_lp, axis=0) - np.log(M)

# compute the expected information gain for every design
eig = (conditional_lp.T - marginal_lp).sum(axis=0) / N
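With the per-design EIG in hand, picking the most informative candidate is just an argmax over the grid (quick sanity-check snippet):

designs = np.linspace(xmin, xmax, D)
print(f'best design: {designs[np.argmax(eig)]:.3f} (EIG estimate: {eig.max():.3f})')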