A deterministic switch can be tricky when sampling with HMC/NUTS: for example, NaNs in the branch that is not selected can still leak into the gradient. My suggestion is to use a multiplication “trick” instead:
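```python
import numpy as np

missing = np.isnan(x)   # True where the predictor x is missing
x_b = np.ones_like(x)
x_copy = x.copy()
x_copy[missing] = 0.    # replace NaNs so that a * x_copy stays finite
x_b[missing] = 0.       # switch off the intercept where x is missing
```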
With these arrays, `mu = a * x_copy + b * x_b` gives the same result as `mu = at.switch(np.isnan(x), 0.0, a*x + b)`, but without the switch.
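A minimal NumPy sketch to check the trick, using made-up data (the values of `x`, `a` and `b` are illustrative assumptions, not from the original model). It compares the masked product against `np.where`, the NumPy analogue of `at.switch`, and writes out by hand the chain-rule gradient of `sum(mu)` with respect to `a`, assuming the backward pass goes through both branches of the switch: the zeroed upstream gradient is still multiplied by `x`, and `0 * nan` is `nan`.

```python
import numpy as np

# Illustrative data: two missing predictor values
x = np.array([1.0, np.nan, 2.0, np.nan])
a, b = 0.5, -1.0

missing = np.isnan(x)
x_copy = x.copy()
x_copy[missing] = 0.
x_b = np.ones_like(x)
x_b[missing] = 0.

# Forward pass: switch version vs multiplication trick
mu_switch = np.where(missing, 0.0, a * x + b)
mu_trick = a * x_copy + b * x_b

# Hand-written d sum(mu) / d a (assumed backward pass, for illustration):
# switch: upstream gradient is zeroed on missing rows, then multiplied by x
grad_a_switch = np.sum(np.where(missing, 0.0, 1.0) * x)
# trick: the derivative of a * x_copy with respect to a is just x_copy
grad_a_trick = np.sum(x_copy)

print(np.allclose(mu_switch, mu_trick))  # True
print(grad_a_switch, grad_a_trick)       # nan 3.0
```

The forward values agree, but only the multiplication version keeps the gradient finite.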
Then, for `sigma` on the rows where `x` is missing, don't create a new free random variable; instead, set it to a fixed value. The value doesn't matter, as it just adds a constant to the log_prob:
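```python
import aesara
import aesara.tensor as at
import numpy as np
import pymc as pm

# Fixed sigma for the rows where x is missing; its value only shifts log_prob by a constant
sigma_fixed = aesara.shared(1.)

with pm.Model() as model_1:
    a = pm.Normal('a', 0.0, 1.0)
    b = pm.Normal('b', 0.0, 1.0)
    sigma = pm.Exponential('sigma', 1.0)
    # Index 0 -> free sigma (x observed), index 1 -> fixed sigma (x missing)
    sigma_ = at.stack([sigma, sigma_fixed])
    mu = a * x_copy + b * x_b
    y_normal = pm.Normal('y_normal', mu, sigma_[np.isnan(x).astype(int)], observed=y)
    idata = pm.sample()
```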
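The claim that the fixed value doesn't matter can be checked outside PyMC. The NumPy sketch below uses made-up `x`, `y` and parameter values, plus a hand-written `normal_logpdf` helper (hypothetical, not part of PyMC). It mimics the model's likelihood, including the `sigma_[np.isnan(x).astype(int)]` lookup, and shows that the difference in total log-probability between two parameter settings is the same for every `sigma_fixed`, so the posterior is unaffected.

```python
import numpy as np

def normal_logpdf(y, mu, sigma):
    # Elementwise log-density of Normal(mu, sigma) (hypothetical helper)
    return -0.5 * np.log(2 * np.pi) - np.log(sigma) - 0.5 * ((y - mu) / sigma) ** 2

# Illustrative data: x is missing in rows 1 and 3, y is fully observed
x = np.array([1.0, np.nan, 2.0, np.nan])
y = np.array([0.3, 1.2, -0.4, 0.8])
missing = np.isnan(x)
x_copy = np.where(missing, 0.0, x)
x_b = np.where(missing, 0.0, 1.0)

def total_logp(a, b, sigma, sigma_fixed):
    mu = a * x_copy + b * x_b
    # Same lookup as sigma_[np.isnan(x).astype(int)]: 0 -> sigma, 1 -> sigma_fixed
    sigma_ = np.array([sigma, sigma_fixed])[missing.astype(int)]
    return np.sum(normal_logpdf(y, mu, sigma_))

# The difference between two parameter settings should not depend on sigma_fixed
diffs = [total_logp(0.5, -1.0, 1.5, s) - total_logp(0.2, 0.3, 0.7, s)
         for s in (0.5, 1.0, 10.0)]
print(diffs)
```

The missing rows evaluate `y` against `Normal(0, sigma_fixed)`, which involves none of `a`, `b` or `sigma`, so they cancel in any comparison between parameter values.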