That seems to work, thanks so much! I’m providing my final toy model here for future reference. (Note: the model below censors the counts 1, 2, and 3 — not 1–9 as described in the initial problem description.)
"""Toy model for censored data"""
import numpy as np
import pymc as pm
true_rate = 0.25
N = 10_000
np.random.seed(45678)
customers = np.random.uniform(low=0, high=25, size=N)
complaints = np.random.poisson(lam=customers * true_rate)
def censor_data(customer_counts=None, complaint_counts=None):
    """Apply the censoring rule and return the surviving observations.

    An observation survives when its customer count is at least 4 AND its
    complaint count is either 0 or at least 4; complaint counts 1-3 are
    censored out.

    Parameters
    ----------
    customer_counts : np.ndarray, optional
        Exposure values. Defaults to the module-level ``customers`` array,
        preserving the original zero-argument call.
    complaint_counts : np.ndarray, optional
        Event counts. Defaults to the module-level ``complaints`` array.

    Returns
    -------
    tuple[np.ndarray, np.ndarray]
        Equal-length (customers, complaints) arrays of the kept observations.
    """
    if customer_counts is None:
        customer_counts = customers
    if complaint_counts is None:
        complaint_counts = complaints
    measured = (customer_counts >= 4) & (
        (complaint_counts == 0) | (complaint_counts >= 4)
    )
    return customer_counts[measured], complaint_counts[measured]
# noinspection PyTypeChecker, PyUnresolvedReferences
def do_sample_naive(n, e):
    """Fit a plain Poisson-rate model to (n, e), ignoring any censoring.

    ``n`` is the exposure (customer count) and ``e`` the observed event
    count per observation; the model is e ~ Poisson(n * r) with an
    Exponential(10) prior on the rate r. Prints a data preview and the
    posterior summary.
    """
    assert len(n) == len(e)
    print(f"{len(n)} observations used")
    print(n[:20])
    print(e[:20])
    with pm.Model():
        r = pm.Exponential("r", 10)
        expected_events = n * r  # Poisson mean scales with exposure
        pm.Poisson("events", mu=expected_events, observed=e)
        trace = pm.sample(2000, tune=1000)
        print(pm.summary(trace))
def sample_censored_wrong():
    """Fit the naive (non-truncated) model to censored data.

    This deliberately ignores the censoring, so the rate estimate is
    expected to be biased.
    """
    kept_customers, kept_complaints = censor_data()
    do_sample_naive(kept_customers, kept_complaints)
def sample_uncensored():
    """Fit the naive model on the full, uncensored data set (baseline)."""
    do_sample_naive(customers, complaints)
# noinspection PyTypeChecker,PyUnresolvedReferences
def sample_censored_correct():
    """Fit a truncated-Poisson model that accounts for the censoring.

    The data exclude complaint counts 1-3, so each observation's Poisson
    log-likelihood is renormalized over the remaining support
    {0, 4, 5, ...} before sampling.
    """
    ns, es = censor_data()
    print(f"{len(ns)} observations used, {(es == 0).sum()} times e=0, {(es > 0).sum()} times e>0")
    print(ns[:20])
    print(es[:20])
    assert not np.any((es == 1) | (es == 2) | (es == 3)), "Invalid data"
    with pm.Model():
        def truncated_poisson_like(e_, n_, r_):
            # Poisson log-pmf at the observed counts, then subtract the
            # log of the total mass remaining after removing 1, 2 and 3.
            rate = n_ * r_
            base_logp = pm.logp(pm.Poisson.dist(mu=rate), e_)
            # P(1) + P(2) + P(3) under Poisson(rate), in closed form.
            censored_mass = pm.math.exp(-rate) * (rate + rate**2 / 2 + rate**3 / 6)
            return base_logp - pm.math.log(1 - censored_mass)

        r = pm.Exponential("r", 10)
        # CustomDist passes the observed value first to logp, followed by
        # the positional parameters (ns, r) in the order given here.
        pm.CustomDist(
            "observed",
            ns,
            r,
            logp=truncated_poisson_like,
            observed=es,
        )
        trace = pm.sample(2000, tune=1000)
        print(pm.summary(trace))
if __name__ == "__main__":
print("\n\n🟢 Uncensored:\n")
sample_uncensored()
print("\n\n🔴 Censored, wrong:\n")
sample_censored_wrong()
print("\n\n🟣 Censored, correct:\n")
sample_censored_correct()
Output:
🟢 Uncensored:
10000 observations used
[19.03993623 6.73949587 4.74468214 13.56126838 23.20609131 9.04828661
14.07107776 13.43000145 15.95306611 15.43764355 18.42642233 3.9369657
21.62208019 15.67332437 10.78355162 6.67373619 3.29083908 4.82998659
6.7086113 11.81542204]
[8 1 2 3 5 1 0 0 6 5 3 0 6 1 1 1 1 3 4 3]
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [r]
Sampling 4 chains for 1_000 tune and 2_000 draw iterations (4_000 + 8_000 draws total) took 1 seconds.[12000/12000 00:01<00:00 Sampling 4 chains, 0 divergences]
mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk ess_tail r_hat
r 0.25 0.001 0.247 0.252 0.0 0.0 3390.0 5606.0 1.0
🔴 Censored, wrong:
4465 observations used
[19.03993623 23.20609131 14.07107776 13.43000145 15.95306611 15.43764355
21.62208019 6.7086113 8.04924026 19.74827761 19.2941839 8.66510745
17.02242486 18.85849719 9.80728482 17.10877876 16.47295421 10.88572265
7.87036873 18.72706509]
[ 8 5 0 0 6 5 6 4 0 7 6 0 8 5 4 6 5 4 4 16]
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [r]
Sampling 4 chains for 1_000 tune and 2_000 draw iterations (4_000 + 8_000 draws total) took 1 seconds.[12000/12000 00:00<00:00 Sampling 4 chains, 0 divergences]
mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk ess_tail r_hat
r 0.3 0.002 0.296 0.303 0.0 0.0 3033.0 5039.0 1.0
🟣 Censored, correct:
4465 observations used, 572 times e=0, 3893 times e>0
[19.03993623 23.20609131 14.07107776 13.43000145 15.95306611 15.43764355
21.62208019 6.7086113 8.04924026 19.74827761 19.2941839 8.66510745
17.02242486 18.85849719 9.80728482 17.10877876 16.47295421 10.88572265
7.87036873 18.72706509]
[ 8 5 0 0 6 5 6 4 0 7 6 0 8 5 4 6 5 4 4 16]
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [r]
Sampling 4 chains for 1_000 tune and 2_000 draw iterations (4_000 + 8_000 draws total) took 2 seconds.[12000/12000 00:01<00:00 Sampling 4 chains, 0 divergences]
mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk ess_tail r_hat
r 0.252 0.002 0.249 0.256 0.0 0.0 3587.0 5185.0 1.0
This is with pymc 5.9.1