Hi, I would like to solve an inference problem with NUTS. After propagating the uncertainty of the input parameters that are not being updated, my observed data has a very irregular distribution. How do I assimilate observed data with an irregular distribution?
To be more specific, the observed data are modeled by the following equations:
x1 = 3*a**2 + 2*b**2 + c + d**2
x2 = a**(3/2) + b**2 + 2*c + d
x3 = a**(5/2) + 4*b**(3/2) + 2*c + 2*d
where a and b are to be updated by Bayesian inference while c and d are not, so the uncertainties of c and d are propagated to the observed data. The measured values of the observed data are [1238.5, 215, 1544], with measurement uncertainty described by a normal distribution with sd = 1. All of a, b, c and d have uniform distributions. The uncertainty propagation is done by simulating the observed data, sampling c, d and the measurement error while keeping a and b constant:
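import numpy as np

n = 10000
c = np.random.uniform(5, 9, n)  # nuisance parameter: propagated, not updated
d = np.random.uniform(2, 6, n)  # nuisance parameter: propagated, not updated
a = 14  # held constant
b = 16  # held constant
# forward model plus N(0, 1) measurement error
samples1 = 3*a**2 + 2*b**2 + c + d**2 + np.random.normal(0, 1, n)
samples2 = a**(3/2) + b**2 + 2*c + d + np.random.normal(0, 1, n)
samples3 = a**(5/2) + 4*b**(3/2) + 2*c + 2*d + np.random.normal(0, 1, n)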
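Seen this way, the propagated samples describe the density of the observations at fixed (a, b): a mixture, over c and d, of normals with sd = 1 centred on the model output. A minimal sketch, assuming fixed Monte Carlo draws of c and d, of estimating that density as log p(y | a, b) ≈ log[(1/K) Σ_k Π_j N(y_j | f_j(a, b, c_k, d_k), 1)] with NumPy/SciPy. The names `forward` and `mc_loglik`, the seed and K are illustrative choices, not part of the original setup:

```python
import numpy as np
from scipy.special import logsumexp
from scipy.stats import norm

# Measured values and measurement sd from the question
y_obs = np.array([1238.5, 215.0, 1544.0])
sigma = 1.0

# Fixed Monte Carlo draws of the nuisance parameters c and d (never updated)
rng = np.random.default_rng(42)
K = 10000
c_draws = rng.uniform(5, 9, K)
d_draws = rng.uniform(2, 6, K)


def forward(a, b, c, d):
    """Forward model (x1, x2, x3); returns an array of shape (len(c), 3)."""
    x1 = 3 * a**2 + 2 * b**2 + c + d**2
    x2 = a**1.5 + b**2 + 2 * c + d
    x3 = a**2.5 + 4 * b**1.5 + 2 * c + 2 * d
    return np.stack([x1, x2, x3], axis=-1)


def mc_loglik(a, b):
    """Monte Carlo estimate of log p(y_obs | a, b) with c and d integrated out."""
    mu = forward(a, b, c_draws, d_draws)  # shape (K, 3)
    # log N(y_j | mu_kj, sigma), summed over the three observations -> shape (K,)
    log_terms = norm.logpdf(y_obs, loc=mu, scale=sigma).sum(axis=1)
    # log of the average over draws, computed stably
    return logsumexp(log_terms) - np.log(K)
```

For example, `mc_loglik(18.0, 11.0)` is far larger than `mc_loglik(14.0, 16.0)`, because a = 14, b = 16 puts x3 hundreds of units away from the measured 1544. Once the draws are fixed, this estimate is a smooth function of a and b, which is the property a gradient-based sampler needs.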
This problem can be solved with SMC-ABC without explicit uncertainty propagation (by modeling the non-updated parameters as noise in the simulator), but I am curious whether it is possible with NUTS, since NUTS is generally much more efficient.
The SMC-ABC code:
import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pymc as pm
az.style.use("arviz-darkgrid")
data = [1238.5, 215, 1544]
rng = np.random.default_rng()
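def normal_sim(rng, a, b, size=1):
    # c and d are nuisance parameters: drawn fresh on every simulation, never updated
    c = rng.uniform(5, 9)
    d = rng.uniform(2, 6)
    return np.array([3*a*a + 2*b*b + c + d**2, a**(3/2) + b*b + 2*c + d, a**(5/2) + 4*b**(3/2) + 2*c + 2*d])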
with pm.Model() as example:
    a = pm.Uniform("a", lower=0.1*16, upper=3*16)
    b = pm.Uniform("b", lower=0.1*14, upper=3*14)
    s = pm.Simulator("s", normal_sim, params=(a, b), sum_stat="sort", epsilon=1, observed=data)
    idata = pm.sample_smc(10000, cores=1, chains=1, threshold=0.6)
    idata.extend(pm.sample_posterior_predictive(idata))

az.plot_trace(idata, kind="rank_vlines");
print(az.summary(idata, kind="stats"))