I am trying to build a simple linear regression model with a switch for the observed variable, i.e., my model looks like this:

```
a ~ N(0, 1)
b ~ N(0, 1)
mu = a*x + b
y ~ N(mu, sigma)   if x is not NaN
y = 0              if x is NaN
```
That is, when x is NaN, y is exactly 0. Usually I would just filter the data and drop every sample where x is NaN. In my actual dataset, however, I am building a multiple linear regression model (i.e., x and y are N-dimensional vectors), and the samples where some x_i is NaN still need to be used for the fitting.
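To make the switch concrete: rows with NaN x enter the likelihood only as a point mass at 0, so they contribute log(1) = 0 when y is 0 and -inf otherwise, and never depend on a, b or sigma. Below is a small NumPy sketch of that log-likelihood. The helper names `normal_logpdf` and `switch_model_loglike` are my own, not PyMC API, and sigma is treated as a standard deviation.

```python
import numpy as np


def normal_logpdf(y, mu, sigma):
    """Elementwise log-density of N(mu, sigma) (sigma = standard deviation) at y."""
    return -0.5 * np.log(2.0 * np.pi) - np.log(sigma) - 0.5 * ((y - mu) / sigma) ** 2


def switch_model_loglike(a, b, sigma, x, y):
    """Log-likelihood of: y ~ N(a*x + b, sigma) if x is not NaN, else y = 0 (point mass)."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    missing = np.isnan(x)
    # NaN rows: log P(y = 0) = log 1 = 0, and any non-zero y is impossible.
    if np.any(y[missing] != 0.0):
        return -np.inf
    observed = ~missing
    mu = a * x[observed] + b
    return float(np.sum(normal_logpdf(y[observed], mu, sigma)))


# The NaN row adds nothing: same value as the likelihood on the filtered data.
x_demo = np.array([1.0, np.nan, 2.0])
y_demo = np.array([3.1, 0.0, 4.9])
full = switch_model_loglike(2.0, 1.0, 0.5, x_demo, y_demo)
dropped = switch_model_loglike(2.0, 1.0, 0.5, x_demo[[0, 2]], y_demo[[0, 2]])
```

In the multivariate case the same reasoning applies per element: a masked x_i removes the corresponding y_i term from the likelihood without removing the whole sample.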
Here is my simple code example:
```python
import pymc as pm
import numpy as np
import aesara.tensor as at
import arviz as az

# Simulate some data
N_samples = 100
p_NaN = 0.1
rng = np.random.RandomState(123)
x = rng.normal(10., 2., size=N_samples)
x[rng.rand(N_samples) < p_NaN] = np.nan  # roughly 10% of x set to NaN
a, b = 2., 1.
sigma = 2.
y = a*x + b + sigma * rng.normal(0.0, 1.0, size=N_samples)
y[np.isnan(y)] = 0.0  # NaN in x propagates to y; set those y to 0
```
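A quick sanity check of the simulated data, using only NumPy. It relies on NaN propagating through `a*x + b`, which is why `np.isnan(y)` marks exactly the rows where x is NaN. The names `x_is_nan`, `n_missing` and `x_filled` are hypothetical helpers introduced here, not part of the original script; `x_filled` is just a NaN-free copy of x.

```python
import numpy as np

# Re-run the simulation above (same seed, same order of draws).
N_samples = 100
p_NaN = 0.1
rng = np.random.RandomState(123)
x = rng.normal(10., 2., size=N_samples)
x[rng.rand(N_samples) < p_NaN] = np.nan
a, b = 2., 1.
sigma = 2.
y = a*x + b + sigma * rng.normal(0.0, 1.0, size=N_samples)
y[np.isnan(y)] = 0.0

# Hypothetical helpers: boolean mask of the NaN rows and how many there are.
x_is_nan = np.isnan(x)
n_missing = int(x_is_nan.sum())
print(f"{n_missing} of {N_samples} rows have NaN x")

# Hypothetical helper: NaN-free copy of x (0.0 is an arbitrary placeholder).
x_filled = np.where(x_is_nan, 0.0, x)
```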
Naively, I would have used something like this:
```python
with pm.Model() as model_1:
    a = pm.Normal('a', 0.0, 1.0)
    b = pm.Normal('b', 0.0, 1.0)
    sigma = pm.Exponential('sigma', 1.0)
    mu = a*x + b
    y_normal = pm.Normal('y_normal', mu, sigma)
    # Not valid: pm.Deterministic has no `observed` argument
    pm.Deterministic('y', at.switch(np.isnan(x), 0.0, y_normal), observed=y)
```
However, this is not possible, because `pm.Deterministic` cannot take observed data (it has no `observed` argument), so the likelihood cannot be attached this way.
So here is my attempt using `aesara.tensor.switch()`:

```python
with pm.Model() as model_1:
    a = pm.Normal('a', 0.0, 1.0)
    b = pm.Normal('b', 0.0, 1.0)
    sigma1 = pm.Exponential('sigma', 1.0)    # noise scale where x is observed
    sigma2 = pm.Exponential('sigma2', 1.0)   # separate noise scale where x is NaN
    mu = at.switch(np.isnan(x), 0.0, a*x + b)
    sigma = at.switch(np.isnan(x), sigma2, sigma1)
    pm.Normal('y', mu, sigma, observed=y)
```
As you can see, I use `at.switch` to set the mean to 0.0 wherever x is NaN, and I also learn a separate sigma (`sigma2`) for those cases.
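One thing I am unsure about is how the switch interacts with the NaN values in x. `at.switch` evaluates both branches, so `a*x + b` is still computed (as NaN) for the masked rows. The forward value is fine, but in reverse mode the gradient with respect to `a` is assembled as (zero from the switch) * x, and 0 * NaN is NaN. I have not checked whether Aesara's graph rewrites avoid this; the NumPy sketch below only does the chain rule by hand to show the mechanism. `upstream`, `grad_branch`, `grad_a_naive` and `x_safe` are hypothetical names.

```python
import numpy as np

x = np.array([1.0, np.nan, 3.0])
mask = np.isnan(x)
a, b = 2.0, 1.0
upstream = np.ones_like(x)  # stand-in for d(logp)/d(mu) coming from the likelihood

# Forward pass: the switch hides the NaN, so the value is finite.
mu = np.where(mask, 0.0, a * x + b)

# Reverse pass, as reverse-mode autodiff evaluates it: the switch sends a zero
# gradient into the unused branch, which is then multiplied by x = d(a*x)/da.
grad_branch = np.where(mask, 0.0, upstream)
grad_a_naive = np.sum(grad_branch * x)  # contains 0 * nan, so the sum is nan

# Replacing NaN in x *before* it enters the computation keeps the gradient finite.
x_safe = np.where(mask, 0.0, x)
grad_a_safe = np.sum(grad_branch * x_safe)
```

If this is what happens in the Aesara graph, passing a NaN-free copy of x into `a*x + b` (and keeping `np.isnan(x)` only for the switch condition) would give the same forward values with finite gradients.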
However, sampling from this model leads to many divergences:
```
Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 80 seconds.
There were 767 divergences after tuning. Increase `target_accept` or reparameterize.
The chain reached the maximum tree depth. Increase max_treedepth, increase target_accept or reparameterize.
There were 789 divergences after tuning. Increase `target_accept` or reparameterize.
There were 872 divergences after tuning. Increase `target_accept` or reparameterize.
The chain reached the maximum tree depth. Increase max_treedepth, increase target_accept or reparameterize.
There were 789 divergences after tuning. Increase `target_accept` or reparameterize.
The chain reached the maximum tree depth. Increase max_treedepth, increase target_accept or reparameterize.
```
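A possible source of these divergences, independent of sampler settings: in the `sigma2` branch every observed y is exactly 0.0 and the switched mean is exactly 0.0, so all residuals are zero and the Normal log-likelihood reduces to -n * log(sigma2) + const. That grows without bound as sigma2 approaches 0, and the Exponential(1) prior does not counteract it (its density stays close to 1 near 0), so the posterior of `sigma2` is improper near zero. A NumPy sketch of the likelihood; the helper `normal_logpdf` and the count of about 10 zero rows are assumptions for illustration.

```python
import numpy as np


def normal_logpdf(y, mu, sigma):
    """Elementwise log-density of N(mu, sigma) (sigma = standard deviation) at y."""
    return -0.5 * np.log(2.0 * np.pi) - np.log(sigma) - 0.5 * ((y - mu) / sigma) ** 2


# Assumption: about 10 rows have NaN x (p_NaN = 0.1, N_samples = 100).
n_zero = 10
y_zero = np.zeros(n_zero)  # observed y in the sigma2 branch: all exactly 0.0
mu_zero = 0.0              # switched mean for those rows: also exactly 0.0

sigma2_grid = [1.0, 1e-2, 1e-4, 1e-8]
loglikes = []
for sigma2 in sigma2_grid:
    # All residuals are 0, so this equals -n_zero * (log(sigma2) + 0.5 * log(2*pi)),
    # which keeps increasing as sigma2 shrinks toward 0.
    loglikes.append(float(normal_logpdf(y_zero, mu_zero, sigma2).sum()))
    print(f"sigma2={sigma2:g}: loglike={loglikes[-1]:.1f}")
```

Since the y = 0 rows are fixed by construction, they carry no information for a noise scale to learn, which is why a learned `sigma2` ends up chasing zero.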
Is there a better way to build such a model or parametrize it?