I am trying to infer the physical conditions in a medium affected by many complicated processes. The physics is time-consuming enough to simulate that I can’t re-run it for each step of an MCMC, for instance. Most people in my field assume the observables vary slowly with respect to the underlying physical conditions, and simulate a relatively sparse grid of observables. I am used to wrapping that grid with an interpolator, but there is no native interpolation in Theano, and I want to be able to take advantage of NUTS without having to write gradients and such.
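For context, this is the kind of interpolation I would otherwise reach for; a minimal sketch with scipy (the axes and tabulated values below are just placeholders), which is fine in numpy-land but can’t live inside a Theano graph or give NUTS gradients:

import numpy as np
from scipy.interpolate import RegularGridInterpolator

# placeholder grid axes and one tabulated observable (illustrative values only)
logZ_axis = np.linspace(-1.0, 0.5, 7)
logU_axis = np.linspace(-4.0, -1.0, 7)
age_axis = np.linspace(1.0, 10.0, 7)
observable = np.random.rand(7, 7, 7)

interp = RegularGridInterpolator((logZ_axis, logU_axis, age_axis), observable)
interp([[0.0, -2.5, 4.0]])   # evaluate at one point in (logZ, logU, age)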
Prompted by these two threads [1, 2], I have reimplemented the interpolation of a handful of sparsely-sampled, N-dimensional function grids as a set of GPs, one per observable. I can solve a system of this form for each observable:
import numpy as np
import pymc3
import pymc3 as pm  # the snippets below use both aliases
import theano

def solve_single_gp(lengthscales, X0, samples, noise=1.0e-3):
    # squared-exponential kernel over the grid axes, one lengthscale per axis
    cov = pm.gp.cov.ExpQuad(len(lengthscales), ls=lengthscales)
    K = cov(X0.T)
    K_noise = K + pm.gp.cov.WhiteNoise(noise)(X0.T)
    # jitter the diagonal for stability, then Cholesky-solve for alpha once, in numpy
    L = np.linalg.cholesky(K_noise.eval())
    alpha = np.linalg.solve(L.T, np.linalg.solve(L, samples.flatten()))
    return cov, alpha
and then use that solution to calculate the GP solution for that observable:
def gp_predictt(cov, alpha, X0, X1):
    # posterior mean of the GP at the new inputs X1, returned as a symbolic theano tensor
    K_s = cov(X0.T, X1.T)
    post_mean = theano.tensor.dot(K_s.T, alpha)
    return post_mean
This code is more or less copied from the GP Slice Sampling tutorial. I don’t care about drawing samples from the GP at a given point, per se; I just want the most likely value of the underlying function given the input grid points.
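To make the shapes concrete, here is roughly how these two functions fit together for a single observable. This is a toy sketch: the grid coordinates, lengthscales, and observable values below are placeholders (the real values come from the photoionization simulations).

import numpy as np
import theano.tensor as tt

# toy stand-in for one observable tabulated on a 3-axis grid
axes = np.meshgrid(np.linspace(-1., 0.5, 5),    # logZ
                   np.linspace(-4., -1., 5),    # logU
                   np.linspace(1., 10., 5),     # age
                   indexing='ij')
X0 = np.stack([a.ravel() for a in axes], axis=0)   # shape (ndim, ngrid)
samples = np.random.rand(X0.shape[1])              # observable at the grid points

# "train" once, outside the pymc3 model
cov, alpha = solve_single_gp(lengthscales=[0.5, 0.5, 2.0], X0=X0, samples=samples)

# predict at symbolic inputs of shape (ndim, npred), e.g. stacked pymc3 random variables
X1 = tt.dmatrix('X1')
post_mean = gp_predictt(cov, alpha, X0, X1)   # symbolic, so NUTS gets gradients for free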
So, from here, I can take this set of GP solutions and use them to model observations:
def find_ism_params(grid, dustlaw, line_obs, line_ls, drpall_row):
    '''
    run a pymc3 model on a whole galaxy

    - grid: yields from the pre-trained GP photoionization grid
    - dustlaw: dust attenuation function
    - line_obs: emission-line observation object
    - line_ls: line wavelengths
    - drpall_row: metadata row
    '''
    # assumes `cosmo` is an astropy cosmology object and `units` is astropy.units
    zdist = drpall_row['nsa_zdist']
    four_pi_r2 = (4. * np.pi * cosmo.luminosity_distance(zdist)**2.).to(units.cm**2).value

    # get observations
    f, unc, _ = map(lambda x: np.stack(x).T, line_obs.get_good_obs())
    cov = np.stack([np.diag(u**2.) for u in unc], axis=-1)
    f, unc, cov = f[:1], unc[:1], cov[..., :1]  # restrict to the first measurement while testing

    *obs_shape_, nlines = f.shape
    obs_shape = tuple(obs_shape_)
    print('in galaxy: {} measurements of {} lines'.format(obs_shape, nlines))

    with pymc3.Model() as model:
        # priors
        ## first on the photoionization model
        logZ = pymc3.Uniform('logZ', *grid.range('logZ'), shape=obs_shape,
                             testval=0.)
        logU = pymc3.Uniform('logU', *grid.range('logU'), shape=obs_shape,
                             testval=-2.5)
        age = pymc3.Uniform('age', *grid.range('Age'), shape=obs_shape,
                            testval=4.)
        grid_params = theano.tensor.stack([logZ, logU, age], axis=0)

        ## next on the normalization of the emission-line strengths
        logQH = pymc3.Normal('logQH', 50., 3., shape=obs_shape + (1, ), testval=50)
        linelumsperqh = grid.predictt(grid_params)  # calls the GP grid for all observables
        linelums = linelumsperqh * 10**logQH  # scale factor for luminosities

        ## next on the dust model
        extinction_at_AV1 = theano.shared(  # shape (nlines, )
            dustlaw(wave=line_ls, a_v=1., r_v=3.1))
        AV = pymc3.Exponential(  # shape (*obs_shape, )
            'AV', 3., shape=obs_shape, testval=1.)  # extinction in V-band
        twopointfive = theano.shared(2.5)
        A_lambda = theano.tensor.outer(AV, extinction_at_AV1)
        atten = 10**(-A_lambda / twopointfive)

        # dim lines based on distance
        distmod = theano.shared(four_pi_r2)
        one_e_minus17 = theano.shared(1.0e-17)
        linefluxes = linelums * atten / distmod / one_e_minus17

        # allow for under-estimated observational uncertainties
        ln_unc_underestimate_factor = pymc3.Uniform(
            'ln-unc-underestimate', -10., 10., testval=0.)

        linefluxes_obs = pymc3.Normal(
            'fluxes-obs', mu=linefluxes,
            sd=unc * theano.tensor.exp(ln_unc_underestimate_factor),
            observed=f, shape=f.shape)

        map_start = pymc3.find_MAP()
        trace = pymc3.sample(draws=5000, tune=500, cores=6, start=map_start)

    return model, trace
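For completeness, grid above is a thin wrapper around the per-line GP solutions, exposing a range method for the prior bounds and the predictt method used in the model. Roughly, it looks like the simplified sketch below (names and bookkeeping here are illustrative, not my exact class):

import theano.tensor as tt

class GPGrid(object):
    def __init__(self, covs, alphas, X0, param_ranges):
        # covs/alphas: one (cov, alpha) pair per emission line, from solve_single_gp
        # X0: (ndim, ngrid) grid coordinates; param_ranges: e.g. {'logZ': (-1., 0.5), ...}
        self.covs, self.alphas, self.X0 = covs, alphas, X0
        self.param_ranges = param_ranges

    def range(self, name):
        # (lower, upper) bounds of one grid axis, used for the Uniform priors
        return self.param_ranges[name]

    def predictt(self, grid_params):
        # grid_params: symbolic (ndim, nobs) tensor of (logZ, logU, age) per measurement
        preds = [gp_predictt(cov, alpha, self.X0, grid_params)
                 for cov, alpha in zip(self.covs, self.alphas)]
        return tt.stack(preds, axis=-1)   # shape (nobs, nlines)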
Ultimately, I have predictions for 9 observables, which depend on ~5 unobserved quantities that I’m trying to model. This seems to work well when there are just one or two sets of observations (i.e., when f has a shape like (a few, 9)); however, I am trying to run this analysis simultaneously for ~2000 sets of observations (I think these can be referred to as “plates”, but I’m not sure I’m using that term right). In that case it is super slow (about 30 seconds per step), which would only make sense if predictt were being run once per chain per step.
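In case it helps with diagnosis, one thing I can do is profile the logp and gradient graphs with pymc3’s built-in profiling to confirm where each evaluation spends its time (sketch below; model is the pymc3.Model returned by find_ism_params):

from pymc3.theanof import gradient

model.profile(model.logpt).summary()                        # cost of one logp evaluation
model.profile(gradient(model.logpt, model.vars)).summary()  # cost of one gradient evaluation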
Can anyone offer suggestions for how to speed this up? Is there a way to pool the GP calls across all the chains so they are evaluated together? Many thanks in advance for your thoughts.