Is this a censored-data problem or not?

That seems to work, thanks so much! I’m providing my final toy model here for future reference. (Note that I’m actually censoring complaint counts of 1, 2, and 3, not 1…9 as described in the initial problem description.)

"""Toy model for censored data"""

import numpy as np
import pymc as pm

true_rate = 0.25
N = 10_000

np.random.seed(45678)

customers = np.random.uniform(low=0, high=25, size=N)
complaints = np.random.poisson(lam=customers * true_rate)


def censor_data():
    # Keep only observations with at least 4 customers and a complaint
    # count of 0 or >= 4 (the counts 1, 2 and 3 are censored out).
    measured = (customers >= 4) & ((complaints == 0) | (complaints >= 4))
    n = customers[measured]
    e = complaints[measured]
    return n, e


# noinspection PyTypeChecker, PyUnresolvedReferences
def do_sample_naive(n, e):
    """Fit a plain Poisson model with no correction for censoring."""
    assert len(n) == len(e)
    print(f"{len(n)} observations used")
    print(n[:20])
    print(e[:20])
    with pm.Model():
        r = pm.Exponential("r", 10)
        pm.Poisson("events", mu=n * r, observed=e)
        trace = pm.sample(2000, tune=1000)
        print(pm.summary(trace))


def sample_censored_wrong():
    n, e = censor_data()
    do_sample_naive(n, e)


def sample_uncensored():
    do_sample_naive(customers, complaints)


# noinspection PyTypeChecker,PyUnresolvedReferences
def sample_censored_correct():
    """Fit a Poisson model whose likelihood is renormalized for the censored counts 1-3."""
    ns, es = censor_data()

    print(f"{len(ns)} observations used, {(es == 0).sum()} times e=0, {(es > 0).sum()} times e>0")
    print(ns[:20])
    print(es[:20])

    assert not np.any((es == 1) | (es == 2) | (es == 3)), "Invalid data"

    with pm.Model():

        def truncated_poisson_like(e_, n_, r_):
            lbda = n_ * r_
            log_like = pm.logp(pm.Poisson.dist(mu=lbda), e_)

            # Normalization for the truncated values: subtract probabilities for 1, 2 and 3
            prob_sum = pm.math.exp(-lbda) * (lbda + lbda**2 / 2 + lbda**3 / 6)
            norm_constant = 1 - prob_sum
            return log_like - pm.math.log(norm_constant)

        r = pm.Exponential("r", 10)
        pm.CustomDist(
            "observed",
            ns,
            r,
            logp=truncated_poisson_like,
            observed=es,
        )

        trace = pm.sample(2000, tune=1000)
        print(pm.summary(trace))


if __name__ == "__main__":
    print("\n\n🟢 Uncensored:\n")
    sample_uncensored()

    print("\n\n🔴 Censored, wrong:\n")
    sample_censored_wrong()

    print("\n\n🟣 Censored, correct:\n")
    sample_censored_correct()
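
As an aside, the normalization in truncated_poisson_like can be sanity-checked numerically: the renormalized pmf should sum to one over the allowed support {0, 4, 5, …}. A minimal standalone sketch using scipy.stats (assuming SciPy is available; not part of the model itself):

import numpy as np
from scipy import stats

lam = 2.5  # arbitrary test value for n * r
k = np.arange(0, 200)  # effectively the full support
pmf = stats.poisson.pmf(k, lam)

# Drop the censored counts 1, 2 and 3 and renormalize,
# mirroring truncated_poisson_like above
keep = (k == 0) | (k >= 4)
norm_constant = 1 - pmf[1] - pmf[2] - pmf[3]
truncated_pmf = pmf[keep] / norm_constant

print(truncated_pmf.sum())  # ~1.0, up to the tail cut at k = 200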

Output:

🟢 Uncensored:

10000 observations used
[19.03993623  6.73949587  4.74468214 13.56126838 23.20609131  9.04828661
 14.07107776 13.43000145 15.95306611 15.43764355 18.42642233  3.9369657
 21.62208019 15.67332437 10.78355162  6.67373619  3.29083908  4.82998659
  6.7086113  11.81542204]
[8 1 2 3 5 1 0 0 6 5 3 0 6 1 1 1 1 3 4 3]
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [r]
Sampling 4 chains for 1_000 tune and 2_000 draw iterations (4_000 + 8_000 draws total) took 1 seconds.
[12000/12000 00:01<00:00 Sampling 4 chains, 0 divergences]
   mean     sd  hdi_3%  hdi_97%  mcse_mean  mcse_sd  ess_bulk  ess_tail  r_hat
r  0.25  0.001   0.247    0.252        0.0      0.0    3390.0    5606.0    1.0


🔴 Censored, wrong:

4465 observations used
[19.03993623 23.20609131 14.07107776 13.43000145 15.95306611 15.43764355
 21.62208019  6.7086113   8.04924026 19.74827761 19.2941839   8.66510745
 17.02242486 18.85849719  9.80728482 17.10877876 16.47295421 10.88572265
  7.87036873 18.72706509]
[ 8  5  0  0  6  5  6  4  0  7  6  0  8  5  4  6  5  4  4 16]
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [r]
Sampling 4 chains for 1_000 tune and 2_000 draw iterations (4_000 + 8_000 draws total) took 1 seconds.
[12000/12000 00:00<00:00 Sampling 4 chains, 0 divergences]
   mean     sd  hdi_3%  hdi_97%  mcse_mean  mcse_sd  ess_bulk  ess_tail  r_hat
r   0.3  0.002   0.296    0.303        0.0      0.0    3033.0    5039.0    1.0


🟣 Censored, correct:

4465 observations used, 572 times e=0, 3893 times e>0
[19.03993623 23.20609131 14.07107776 13.43000145 15.95306611 15.43764355
 21.62208019  6.7086113   8.04924026 19.74827761 19.2941839   8.66510745
 17.02242486 18.85849719  9.80728482 17.10877876 16.47295421 10.88572265
  7.87036873 18.72706509]
[ 8  5  0  0  6  5  6  4  0  7  6  0  8  5  4  6  5  4  4 16]
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [r]
Sampling 4 chains for 1_000 tune and 2_000 draw iterations (4_000 + 8_000 draws total) took 2 seconds.
[12000/12000 00:01<00:00 Sampling 4 chains, 0 divergences]
    mean     sd  hdi_3%  hdi_97%  mcse_mean  mcse_sd  ess_bulk  ess_tail  r_hat
r  0.252  0.002   0.249    0.256        0.0      0.0    3587.0    5185.0    1.0

This is with PyMC 5.9.1.
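
As an additional cross-check outside PyMC, the same truncated likelihood can be maximized directly with scipy.optimize. This is just a sketch that regenerates the toy data from the script above; the maximum-likelihood estimate of r should land near the true rate of 0.25:

import numpy as np
from scipy import optimize, stats

# Regenerate the toy data exactly as in the script above
true_rate = 0.25
N = 10_000
np.random.seed(45678)
customers = np.random.uniform(low=0, high=25, size=N)
complaints = np.random.poisson(lam=customers * true_rate)

measured = (customers >= 4) & ((complaints == 0) | (complaints >= 4))
n, e = customers[measured], complaints[measured]


def neg_log_like(r):
    lbda = n * r
    log_like = stats.poisson.logpmf(e, lbda)
    # Same normalization as truncated_poisson_like: remove the mass at 1, 2 and 3
    norm_constant = 1 - np.exp(-lbda) * (lbda + lbda**2 / 2 + lbda**3 / 6)
    return -(log_like - np.log(norm_constant)).sum()


result = optimize.minimize_scalar(neg_log_like, bounds=(1e-6, 2), method="bounded")
print(result.x)  # should be close to 0.25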
