Hi! I’m trying to calibrate a physical model using PyMC. I’ve gone with the approach of wrapping my jax model in a PyTensor Op. This works with the Metropolis sampler, although Metropolis runs into some issues during sampling. Since I want to take advantage of NUTS, I implemented a grad method and a gradient Op for my model:
import jax
import jax.numpy as jnp
import numpy as np
import pytensor.tensor as pt
from pytensor.graph import Op

@jax.jit
def impedance_logp(observed, theta, ws):
    """Gaussian log-likelihood of the impedance model given parameters theta."""
    (h1, h2, sigma, mu1, mu2) = theta[0]
    model = impedance_operation(ws, h1, h2, mu1, mu2)
    logp_impedance = -(0.5 / sigma**2) * jnp.sum((observed - model) ** 2)
    return jnp.array(logp_impedance)
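Here impedance_operation is my physical model, a differentiable jax function of the frequency grid ws and the layer parameters. A hypothetical stand-in with the same signature, just so the snippet is self-contained, could be:

def impedance_operation(ws, h1, h2, mu1, mu2):
    # Hypothetical stand-in for my real physical model: any smooth jax
    # function of the frequencies and parameters with this signature works.
    return mu1 * jnp.exp(-h1 * ws) + mu2 * jnp.exp(-h2 * ws)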
class ImpedanceOp(Op):
    itypes = [pt.dvector]  # expects a vector of parameter values when called
    otypes = [pt.dscalar]  # outputs a single scalar value (the log-likelihood)

    def __init__(self, ws, data, **kwargs):
        super().__init__(**kwargs)
        self.ws = ws
        self.data = data
        self.logpgrad = ImpedanceGrad(self.data, self.ws)

    def impedance_logp(self, theta):
        return impedance_logp(self.data, theta, self.ws)

    def perform(self, node, inputs, outputs):
        result = self.impedance_logp(inputs)
        outputs[0][0] = np.asarray(result, dtype=node.outputs[0].dtype)

    def grad(self, inputs, output_gradients):
        # Chain rule: scale the gradient Op's output by the upstream gradient.
        grads = self.logpgrad(*inputs)
        output_gradient = output_gradients[0]
        return [output_gradient * grads]
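The forward pass itself seems fine (Metropolis only needs perform). A minimal check along these lines, with ws_test and data_test as placeholder arrays, returns a scalar log-likelihood:

ws_test = np.linspace(0.1, 10.0, 50)  # placeholder frequency grid
data_test = np.zeros(50)              # placeholder measurements
logp_op = ImpedanceOp(ws_test, data_test)
print(logp_op(pt.as_tensor_variable(np.ones(5))).eval())  # scalar log-likelihood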
class ImpedanceGrad(Op):
    itypes = [pt.dvector]  # takes the same parameter vector as ImpedanceOp
    otypes = [pt.dvector]  # outputs the gradient of the log-likelihood w.r.t. theta

    def __init__(self, data, ws) -> None:
        super().__init__()
        self.data = data
        self.ws = ws

    def impedance_logp(self, theta):
        return impedance_logp(self.data, theta, self.ws)

    def impedance_logp_grad(self, theta):
        # Differentiate the jax log-likelihood with respect to theta.
        return jax.grad(self.impedance_logp)(theta)

    def perform(self, node, inputs, outputs):
        grads = self.impedance_logp_grad(inputs)
        outputs[0][0] = np.asarray(grads, dtype=node.outputs[0].dtype)
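To sanity-check the gradient machinery outside of PyMC, I believe something like pytensor.gradient.verify_grad (which compares an Op's grad against finite differences) should flag any problem. A sketch, reusing the placeholder arrays from above:

import pytensor

pytensor.gradient.verify_grad(
    ImpedanceOp(ws_test, data_test),  # the Op whose grad method is being checked
    [np.ones(5)],                     # test point for the finite-difference comparison
    rng=np.random.default_rng(42),
)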
And here is the model I’m trying to sample from, where impedance_mean is an instance of the Op above, built from my frequency grid ws and my measurements data:

import os

import arviz as az
import pymc as pm

impedance_mean = ImpedanceOp(ws, data)

with pm.Model() as impedance_calibration:
    h1 = pm.HalfNormal("h1", sigma=1)
    h2 = pm.HalfNormal("h2", sigma=1)
    mu1 = pm.HalfNormal("mu1", sigma=1)
    mu2 = pm.HalfNormal("mu2", sigma=1)
    # sigma = pm.TruncatedNormal("sigma", sigma=1, mu=5, lower=0, upper=10)
    sigma = pm.HalfNormal("sigma", sigma=1)
    theta = pt.as_tensor_variable([h1, h2, sigma, mu1, mu2])
    likelihood = pm.Potential(
        "likelihood",
        impedance_mean(theta),
    )

    if not os.path.exists(TRACE_PATH):
        # step = pm.Metropolis()
        # trace = pm.sample(step=step, tune=10000, draws=5000)
        trace = pm.sample(tune=10000, draws=5000)
        # trace = pm.sample_smc(draws=5000, chains=4, parallel=True)
        trace.to_netcdf(TRACE_PATH)
    else:
        trace = az.from_netcdf(TRACE_PATH)
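To try to narrow this down myself, my plan was to evaluate the gradient of the joint log-density at the initial point, outside the parallel sampler. A sketch using pymc's Model.initial_point and Model.compile_dlogp:

ip = impedance_calibration.initial_point()     # dict of starting values for the free variables
dlogp = impedance_calibration.compile_dlogp()  # compiled gradient of the joint log-density
print(dlogp(ip))                               # hopefully reproduces the error with a usable traceback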
The problem is that sampling with NUTS yields the following error:
pymc.sampling.parallel.ParallelSamplingError: Chain 0 failed with: Expected 1 dimensions input
Apply node that caused the error: Subtensor{i}(ImpedanceGrad.0, 0)
Toposort index: 10
Inputs types: [TensorType(float64, shape=(None,)), ScalarType(uint8)]
Inputs shapes: [(1, 5), ()]
Inputs strides: [(40, 8), ()]
Inputs values: [array([[ 4.21017600e+08, -1.44086746e+09, 2.44285594e+09,
-1.65140378e+09, -8.00240256e+08]]), 0]
Outputs clients: [[Composite{(((i0 + i1) * i2) + i3)}(Composite{...}.2, Subtensor{i}.0, h1, (d__logp/dh1_log___logprob){1.0})]]
I need some help interpreting this: which of the variables, exactly, is causing the trouble?
Thanks for the help!