Hi there,
I’m implementing a simple Nested Monte Carlo (NMC) estimator of the Expected Information Gain (EIG) for a logistic regression model. That is, I want to find the expected decrease in uncertainty about the model parameters for a candidate next sampling point (or design) d:
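For readers without the paper at hand, the quantity is usually defined as follows. The notation is an assumption on my part rather than something taken from the code: θ = (k, x0) collects the model parameters, y is the outcome observed at design d, and H denotes entropy.

```latex
\mathrm{EIG}(d)
  = \mathbb{E}_{p(y \mid d)}\!\left[ H\big[p(\theta)\big] - H\big[p(\theta \mid y, d)\big] \right]
  = \mathbb{E}_{p(\theta)\, p(y \mid \theta, d)}\!\left[ \log \frac{p(y \mid \theta, d)}{p(y \mid d)} \right]
```

The second form is the one that matters for estimation: the marginal p(y | d) in the denominator has no closed form here, which is why a nested estimator is needed.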
This is intractable, but it can be estimated with the following NMC estimator (see arXiv:1903.05480, Variational Bayesian Optimal Experimental Design):
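As I read the cited paper, the estimator has the form below. The symbols are assumed notation: θ stands for the parameters (k, x0), N is the number of outer samples and M the number of inner samples, matching the `N` and `M` used in the code.

```latex
\hat{\mu}_{\mathrm{NMC}}(d)
  = \frac{1}{N} \sum_{n=1}^{N}
    \log \frac{p(y_n \mid \theta_{n,0}, d)}
              {\frac{1}{M} \sum_{m=1}^{M} p(y_n \mid \theta_{n,m}, d)},
\qquad
\theta_{n,m} \sim p(\theta), \quad
y_n \sim p(y \mid \theta_{n,0}, d)
```

The numerator corresponds to `conditional_lp` in the code, and the log of the denominator to `marginal_lp`.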
There already is a Pyro implementation, but Pyro/PyTorch is not great for actual posterior sampling (their NUTS implementation is quite slow), so I’d like to have something similar in PyMC.
My code so far works, but it’s excruciatingly slow (mostly because of the repeated draw calls I suppose). What would be the PyMC way of writing this in a more efficient way?
import numpy as np
import pymc  # needed for the pymc.* calls below
import scipy.special as sp
from pymc import draw
from pymc import Normal, Bernoulli, Uniform, Model, sample
# true parameters (not used in the model)
true_k = 30
true_x0 = 0.5
# plotting domain
xmin = -0.3
xmax = 1.0
# priors
k_mu, k_sigma = 25.0, 1.0
x0_mu, x0_sigma = 1.0, 1.5
# parameters for the Nested Monte Carlo
N = 200
M = 50
candidates = np.linspace(xmin, xmax, 10)
with Model() as model:
    # priors
    k = Normal('k', mu=k_mu, sigma=k_sigma)
    x0 = Normal('x0', mu=x0_mu, sigma=x0_sigma)
    # the designs
    d = pymc.ConstantData('d', candidates)
    # the observed data
    observed_y = pymc.MutableData('observed_y', np.zeros(10))
    # the likelihood
    y = Bernoulli('y', logit_p=k * (d - x0), observed=observed_y)
    # log probabilities
    lp = pymc.Deterministic('lp', pymc.logp(y, y))
    lp_cond = pymc.Deterministic(
        'lp_cond',
        pymc.logprob.conditional_logp({
            y: observed_y,
        })[observed_y])
# outer samples: draw lp and y together so that each log-likelihood
# is paired with the outcome it was computed for
conditional_lp, observed_data = draw([lp, y], N)
# inner samples: estimate the marginal log-likelihood of each sampled outcome
marginal_lp = np.zeros((M, N, len(candidates)))
for n in range(N):
    observed_y.set_value(observed_data[n])
    marginal_lp[:, n, :] = draw(lp_cond, M)
marginal_lp = sp.logsumexp(marginal_lp, axis=0) - np.log(M)
# compute the expected information gain for every design
eig = (conditional_lp - marginal_lp).sum(axis=0) / N
print(eig)
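As a point of comparison for timing and for sanity-checking the numbers, the same estimator can be sketched in plain NumPy with every draw vectorized. This is only a sketch under assumptions: it bypasses PyMC entirely, it hard-codes the Normal priors and Bernoulli likelihood from the model above, and the names `bernoulli_logp` and `nmc_eig` are hypothetical helpers, not PyMC functions.

```python
import numpy as np
from scipy.special import expit, logsumexp


def bernoulli_logp(y, logit_p):
    """Log-probability of y in {0, 1} under Bernoulli(sigmoid(logit_p)).

    Hypothetical helper: y * logit - log(1 + exp(logit)), computed stably.
    """
    return y * logit_p - np.logaddexp(0.0, logit_p)


def nmc_eig(candidates, k_mu, k_sigma, x0_mu, x0_sigma, N=200, M=50, seed=0):
    """Nested Monte Carlo EIG estimate, one value per candidate design.

    Hypothetical helper that assumes Normal priors on k and x0 and a
    Bernoulli likelihood with logit k * (d - x0).
    """
    rng = np.random.default_rng(seed)
    d = np.asarray(candidates, dtype=float)            # shape (D,)

    # Outer samples: theta_n from the prior, then y_n given theta_n and d.
    k = rng.normal(k_mu, k_sigma, size=(N, 1))
    x0 = rng.normal(x0_mu, x0_sigma, size=(N, 1))
    logit = k * (d - x0)                               # shape (N, D)
    y = (rng.random(logit.shape) < expit(logit)).astype(float)
    conditional_lp = bernoulli_logp(y, logit)          # shape (N, D)

    # Inner samples: M fresh prior draws per outer sample, used to
    # estimate the marginal log-likelihood log p(y_n | d).
    k_in = rng.normal(k_mu, k_sigma, size=(M, N, 1))
    x0_in = rng.normal(x0_mu, x0_sigma, size=(M, N, 1))
    logit_in = k_in * (d - x0_in)                      # shape (M, N, D)
    marginal_lp = logsumexp(bernoulli_logp(y, logit_in), axis=0) - np.log(M)

    # Average the log-ratio over the outer samples.
    return (conditional_lp - marginal_lp).mean(axis=0)
```

Calling `nmc_eig(np.linspace(-0.3, 1.0, 10), 25.0, 1.0, 1.0, 1.5)` returns an array with one EIG estimate per candidate. Because it samples only from the prior, this does not cover the case where the parameters come from a posterior sampled with NUTS.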