Probabilistic model with deterministic switch

Hi,
I am trying to build a simple linear regression model with a switch for the observed variable, i.e., my model looks like this:

a ~ N(0, 1)
b ~ N(0, 1)
mu = a*x + b

y ~ N(mu, sigma) if x is not NaN, else y = 0

That is, when x is NaN, y is exactly 0. Usually I would just filter the data and drop all samples where x is NaN, but in my actual dataset I am building a multiple linear regression model (i.e., x and y are N-dimensional vectors), and the cases where some x_i is NaN need to be used for the fitting.

Here is my simple code example:

import pymc as pm
import numpy as np
import aesara.tensor as at
import arviz as az

# Simulate some data: mark ~10% of the x values as missing
N_samples = 100
p_NaN = 0.1
rng = np.random.RandomState(123)
x = rng.normal(10., 2., size=N_samples)
x[rng.rand(N_samples) < p_NaN] = np.nan

# True parameters
a, b = 2., 1.
sigma = 2.
y = a*x + b + sigma * rng.normal(0.0, 1.0, size=N_samples)
y[np.isnan(y)] = 0.0  # y is exactly 0 wherever x is NaN

Naively, I would have used something like this:

with pm.Model() as model_1:
    a = pm.Normal('a', 0.0, 1.0)
    b = pm.Normal('b', 0.0, 1.0)
    sigma = pm.Exponential('sigma', 1.0)
    mu = a*x + b
    y_normal = pm.Normal('y_normal', mu, sigma)

    pm.Deterministic('y', at.switch(np.isnan(x), 0.0, y_normal), observed=y)

However, this is not possible: pm.Deterministic does not accept an observed argument, because we cannot sample from deterministic variables in PyMC.

So, here is my attempt using aesara.tensor.switch():

with pm.Model() as model_2:
    a = pm.Normal('a', 0.0, 1.0)
    b = pm.Normal('b', 0.0, 1.0)
    sigma1 = pm.Exponential('sigma1', 1.0)
    sigma2 = pm.Exponential('sigma2', 1.0)
    # Mean 0 and a separate sigma for the rows where x is NaN
    mu = at.switch(np.isnan(x), 0.0, a*x + b)
    sigma = at.switch(np.isnan(x), sigma2, sigma1)
    pm.Normal('y', mu, sigma, observed=y)

    idata = pm.sample(target_accept=0.99)

As you can see, here I switch the mean to 0.0 when x is NaN, and also learn a separate sigma for those cases.
However, sampling from this model leads to many divergences:

Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 80 seconds.
There were 767 divergences after tuning. Increase `target_accept` or reparameterize.
The chain reached the maximum tree depth. Increase max_treedepth, increase target_accept or reparameterize.
There were 789 divergences after tuning. Increase `target_accept` or reparameterize.
There were 872 divergences after tuning. Increase `target_accept` or reparameterize.
The chain reached the maximum tree depth. Increase max_treedepth, increase target_accept or reparameterize.
There were 789 divergences after tuning. Increase `target_accept` or reparameterize.
The chain reached the maximum tree depth. Increase max_treedepth, increase target_accept or reparameterize.

Is there a better way to build such a model or parametrize it?

A deterministic switch is tricky to sample with HMC/NUTS. My suggestion is to use a multiplication “trick”:

x_copy = x.copy()
x_b = np.ones_like(x)
x_copy[np.isnan(x)] = 0.  # a * x_copy contributes 0 where x is NaN
x_b[np.isnan(x)] = 0.     # b * x_b contributes 0 where x is NaN

The result is the same as doing mu = at.switch(np.isnan(x), 0.0, a*x + b) but without the switch.
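A quick numpy sanity check of that equivalence (just a sketch, with plain float stand-ins for a and b):

# The masked-multiplication mu matches the switch-based mu exactly
a_val, b_val = 2.0, 1.0
mu_trick = a_val * x_copy + b_val * x_b
mu_switch = np.where(np.isnan(x), 0.0, a_val * x + b_val)
assert np.allclose(mu_trick, mu_switch)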
Then for sigma, don’t create a new free random variable for the NaN rows; instead, fix it to a constant (the value doesn’t matter, as it only adds a constant to the log_prob):

import aesara

sigma_fixed = aesara.shared(1.)

with pm.Model() as model_3:
    a = pm.Normal('a', 0.0, 1.0)
    b = pm.Normal('b', 0.0, 1.0)
    sigma = pm.Exponential('sigma', 1.0)
    sigma_ = at.stack([sigma, sigma_fixed])
    mu = a * x_copy + b * x_b
    # Index with 0 (valid rows) or 1 (NaN rows) to pick the matching sigma
    y_normal = pm.Normal('y_normal', mu, sigma_[np.isnan(x).astype(int)], observed=y)
    idata = pm.sample()
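After sampling, the fit can be checked with ArviZ (using the az import from the first snippet):

az.summary(idata, var_names=['a', 'b', 'sigma'])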

Thank you for the reply, this does work indeed. One thing I noticed is that I need to increase the number of tuning steps quite a lot; otherwise the acceptance probability is too low. Also, in your example the samples where x is NaN have a variance of 1. If I decrease this to better approximate the true value of exactly 0, the sampling leads to very low acceptance probabilities, no matter how many tuning steps I use (I tried up to 10000).

You are probably running into numerical issues. I think you already observed that changing that value does not change the posterior, as those rows only add a constant to the log_prob (for your model, it is equivalent to removing the NaNs).
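To make that concrete (a quick check with scipy; sigma_fixed = 1 as above): every NaN row has y == 0 and mu == 0, so its contribution to the log-probability involves none of a, b, or sigma:

from scipy.stats import norm

# Total contribution of the NaN rows: a constant offset, independent of a, b, sigma
n_nan = np.isnan(x).sum()
const = n_nan * norm.logpdf(0.0, loc=0.0, scale=1.0)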

Yes, true. It is just a bit inconvenient for predicting values, because the posterior prediction for samples with x == NaN will have that variance added. Of course, I can manually correct for it, but ideally I would not have to.
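For reference, a minimal sketch of that manual correction (assuming model_3 and idata from above, and PyMC v4 behaviour where sample_posterior_predictive returns an InferenceData):

with model_3:
    ppc = pm.sample_posterior_predictive(idata)

# The model predicts the NaN rows with sigma_fixed noise, but y is known
# to be exactly 0 there, so overwrite those draws
y_pred = ppc.posterior_predictive['y_normal'].values.copy()
y_pred[..., np.isnan(x)] = 0.0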