I think this works:
# modified from https://www.pymc.io/projects/examples/en/latest/generalized_linear_models/GLM-truncated-censored-regression.html#glm-truncated-censored-regression
import matplotlib.pyplot as plt
import numpy as np
import pymc as pm
import pytensor.tensor as pt
from scipy.stats import norm
# Seeded generator shared by the whole script, so both the synthetic data
# and the soft-truncation draws below are reproducible.
rng = np.random.default_rng(123)
# generate full sample, included data never observed
# True parameters of the simulation: y = slope*x + intercept + Normal(0, scat),
# with N points in total.
slope, intercept, scat, N = 1, 0, 2, 2000
x = rng.uniform(-10, 10, N)  # covariate drawn uniformly on [-10, 10)
y = rng.normal(loc=slope * x + intercept, scale=scat)  # noisy linear response
def soft_truncate_y(x, y, threshold, scale=1.0):
    """Stochastically drop points that fall below a soft detection threshold.

    Each point is kept with probability ``P(y + e > threshold)`` where
    ``e ~ Normal(0, scale)``, i.e. ``1 - Phi((threshold - y) / scale)``.
    With the default ``scale=1.0`` this reproduces the original behaviour,
    equivalent to ``1 - norm.cdf(threshold, y, 1)``.

    Parameters
    ----------
    x, y : ndarray
        Covariate and response arrays of equal length.
    threshold : float
        Soft truncation threshold applied to y.
    scale : float, optional
        Standard deviation of the Gaussian selection noise (default 1.0).

    Returns
    -------
    tuple of ndarray
        ``(x_kept, y_kept)``, the points that survived the random selection.
    """
    # norm.sf computes 1 - cdf directly in the tail, avoiding the
    # cancellation error that the explicit 1 - norm.cdf(...) suffers
    # when the cdf is close to 1.
    p_keep = norm.sf(threshold - y, scale=scale)
    # One uniform draw per point; keep the point when the draw falls at or
    # below its inclusion probability. Uses the module-level `rng` so the
    # whole script consumes a single reproducible random stream.
    u = rng.uniform(size=y.size)
    keep = (p_keep >= u)
    return (x[keep], y[keep])
# included in the observed sample
# Points with y well above the threshold are almost surely kept; points well
# below it are almost surely dropped; near the threshold inclusion is random.
threshold = -1
xt, yt = soft_truncate_y(x, y, threshold)
print('include in the observed sample',yt.size,'out of',y.size)
# plot all
# Full (never fully observed) sample in grey, excluded from the legend.
plt.plot(x, y, ".", c=[0.7, 0.7, 0.7],label='_nolegend_')
# Dashed red line marks the soft truncation threshold.
plt.axhline(threshold, c="r", ls="--")
plt.xlabel("x"), plt.ylabel("y")
# Overplot the surviving (observed) points in black.
plt.plot(xt, yt, ".", c=[0, 0, 0],label='Observed data')
# (biased) fit, neglecting soft truncation
def linear_regression(x, y, adjustment=False):
    """Build a PyMC linear-regression model for the observed (x, y) data.

    When ``adjustment`` is False the model is a plain Normal regression and
    is biased by the soft truncation of the sample. When True, the
    observation likelihood is corrected by conditioning on the point having
    survived the soft truncation at the global ``threshold``.
    Returns the (unsampled) ``pm.Model``.
    """

    def logp(value, mu, sigma):
        # Log-density of a plain Normal(mu, sigma) at the observed values.
        normal_term = pm.logp(pm.Normal.dist(mu, sigma), value)
        # Probability that a point was observed at all: the Normal(mu, sigma)
        # response convolved with the N(0, 1) selection noise, evaluated at
        # the global `threshold`. log(1 - cdf) is taken via log1mexp for
        # numerical stability.
        total_sd = pt.sqrt(sigma ** 2 + 1 ** 2)
        log_p_observed = pt.log1mexp(
            pm.logcdf(pm.Normal.dist(mu, total_sd), threshold)
        )
        # Condition the density on the point having been observed.
        return normal_term - log_p_observed

    with pm.Model() as model:
        # Weakly informative priors on the regression parameters.
        slope = pm.Normal("slope", mu=0, sigma=2)
        intercept = pm.Normal("intercept", mu=0, sigma=1)
        sigma = pm.HalfNormal("sigma", sigma=1)
        mean = slope * x + intercept
        if adjustment:
            # Truncation-aware likelihood via the custom logp above.
            pm.CustomDist("obs", mean, sigma, logp=logp, observed=y)
        else:
            # Naive (biased) Normal likelihood.
            pm.Normal("obs", mu=mean, sigma=sigma, observed=y)
    return model
# Fit twice: first ignoring the soft truncation (biased), then with the
# truncation-aware likelihood, and plot posteriors against the true values.
for adjustment in (False, True):
    model = linear_regression(xt, yt, adjustment)
    with model:
        trunc_linear_fit = pm.sample(random_seed=rng)
        # ref_val lines up with var_names: intercept=0, slope=1, sigma=scat=2.
        pm.plots.plot_posterior(
            trunc_linear_fit,
            var_names=["intercept", "slope", "sigma"],
            ref_val=[intercept, slope, scat],
        )
I modified the threshold logic a bit to make it more similar to the logp. Also, you were not truncating at the threshold — I imagine you wanted `y - threshold` instead of `y + threshold`?
Without adjustment:
With adjustment: