Compiling the function does help, but you still end up with lots of repeated invocations, which is probably not ideal (I haven't benchmarked that yet, but it's unbearably slow).
I tried batching the other variables as well, and the code below seems to work now. However, it's still roughly 70% slower than the Pyro implementation. I'll need to check what takes so long (see the timing sketch after the code), but it's also causing problems with actually sampling from the posterior.
# assumes D, N, M, k_mu, k_sigma, x0_mu, x0_sigma, xmin and xmax are defined as before
import numpy as np
import pymc
import pytensor
from pymc import Bernoulli, Model, Normal
from pymc.pytensorf import compile_pymc
from pytensor.graph.replace import vectorize_graph
from scipy import special as sp

with Model() as model:
    # the observed data
    observed_y = pymc.MutableData('observed_y', np.zeros((D, 2)))
    # the designs
    d = pymc.MutableData('d', np.ones((D, 2)))
    # priors
    k = Normal('k', mu=k_mu, sigma=k_sigma, shape=d.shape, dims='data')
    x0 = Normal('x0', mu=x0_mu, sigma=x0_sigma, shape=d.shape, dims='data')
    # parameters
    logit = pymc.Deterministic('logit', k * (d - x0))
    # the likelihood
    y = Bernoulli('y',
                  logit_p=logit,
                  observed=observed_y,
                  shape=d.shape,
                  dims='data')
    # log probabilities
    lp = pymc.Deterministic('lp', pymc.logp(y, y))
    lp_cond = pymc.Deterministic(
        'lp_cond',
        pymc.logprob.conditional_logp({
            y: observed_y,
        })[observed_y])
# compile the draw functions
cdraw1 = compile_pymc(inputs=[], outputs=[lp, y])
# symbolic stand-ins for the data and designs, so lp_cond can be
# re-evaluated on new inputs without mutating the model's shared data
new_observed_y = pytensor.tensor.matrix('new_observed_y')
new_d = pytensor.tensor.matrix('new_d')
new_lp_cond = vectorize_graph(lp_cond, {observed_y: new_observed_y, d: new_d})
cdraw2 = compile_pymc(inputs=[new_observed_y, new_d], outputs=new_lp_cond)
# set the designs (candidate sampling points)
d.set_value(np.repeat(np.linspace(xmin, xmax, D)[:, np.newaxis], N, axis=1))
# draw the data and its conditional log-probabilities from the prior predictive
conditional_lp, observed_data = cdraw1()
# tile each data draw M times for the marginal Monte Carlo estimate below
observed_data = np.repeat(observed_data[:, :, np.newaxis], M, axis=2)
new_d_value = np.repeat(np.linspace(xmin, xmax, D)[:, np.newaxis], M, axis=1)
# evaluate each data draw's marginal log-probability under M fresh prior samples
marginal_lp = np.ones((M, N, D))
for n in range(N):
    marginal_lp[:, n, :] = cdraw2(observed_data[:, n], new_d_value).T
marginal_lp = sp.logsumexp(marginal_lp, axis=0) - np.log(M)
# compute the expected information gain for every design
eig = (conditional_lp.T - marginal_lp).sum(axis=0) / N
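For reference, as far as I can tell the last few lines are just the standard nested Monte Carlo estimator of the expected information gain, with $N$ outer data draws and $M$ inner prior samples per candidate design $d$:

$$\widehat{\mathrm{EIG}}(d) = \frac{1}{N} \sum_{n=1}^{N} \left[ \log p(y_n \mid \theta_n, d) - \log \frac{1}{M} \sum_{m=1}^{M} p(y_n \mid \theta_m, d) \right]$$

where $\theta = (k, x_0)$, and the inner average is the `logsumexp(...) - np.log(M)` term above.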
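As for checking what takes so long, my first step would be to time the two compiled functions and the Python loop separately, along the lines of the (untested) sketch below; `timeit` is just the standard library module, and the call signatures match the code above.

import timeit

# rough per-call cost of the forward draw and of one vectorized
# marginal evaluation (untested sketch; run after the code above)
t1 = timeit.timeit(cdraw1, number=100) / 100
t2 = timeit.timeit(lambda: cdraw2(observed_data[:, 0], new_d_value), number=100) / 100
print(f'cdraw1: {1e3 * t1:.2f} ms/call')
print(f'cdraw2: {1e3 * t2:.2f} ms/call -> ~{1e3 * t2 * N:.0f} ms for the N-iteration loop')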