Integrating an untruncated sum into a truncated model

I am using an insurance claims dataset that only shows claims over a certain size (it is truncated). However, it also includes the sum of all claims, including the unreported ones (but not their number). It strikes me that this would be super useful information for reconstructing the whole distribution.

However, I’m not sure how to integrate this information when fitting my TruncatedNormal. Does anyone have any insight?

Data setup:

```python
import numpy as np
import pymc as pm
import matplotlib.pyplot as plt
import arviz as az

print(f"pymc: {pm.__version__}")

rng = np.random.default_rng(20240802)

# insurance losses
draws = rng.lognormal(10, 3, size=100_000)

# total loss - this number is provided
total_loss = draws.sum()

# we only get individual losses reported above this number
min_reported_loss = 1_000_000

print(f"We see only {(draws > min_reported_loss).mean():.2%} of losses")

observed_losses = draws[draws > min_reported_loss]

print(f"Sampled mean of log-losses {np.log(draws).mean():.3f}")
print(f"Sampled std of log-losses {np.log(draws).std():.3f}")

plt.hist(np.log(draws), alpha=0.5)
plt.axvline(np.log(min_reported_loss), color="black", linestyle="--")
plt.show()
```
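Because `total_loss` covers every claim, subtracting the reported claims leaves the sum of the unreported (below-threshold) claims. That is the extra piece of information the truncated likelihood currently ignores. A quick numpy check on the simulated data, regenerated with the same seed so the snippet runs on its own (`unreported_sum` and `hidden` are just illustrative names):

```python
import numpy as np

rng = np.random.default_rng(20240802)
draws = rng.lognormal(10, 3, size=100_000)
total_loss = draws.sum()
min_reported_loss = 1_000_000
observed_losses = draws[draws > min_reported_loss]

# What the total tells us beyond the reported claims: the sum of the unreported ones
unreported_sum = total_loss - observed_losses.sum()

# Only possible because the data are simulated: compare with the hidden claims directly
hidden = draws[draws <= min_reported_loss]

print(f"Unreported sum (from total):   {unreported_sum:,.0f}")
print(f"Unreported sum (hidden draws): {hidden.sum():,.0f}")
print(f"Share of total loss that is unreported: {unreported_sum / total_loss:.2%}")
```

In real data only `total_loss` and `observed_losses` are available, so `unreported_sum` is the quantity a model would have to explain.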

And the model:

```python
print(f"Observed count: {len(observed_losses):,}\n")

with pm.Model() as model:
    mu = pm.Normal("mu", 0, 10)
    sigma = pm.HalfNormal("sigma", 5)
    # log-losses are modelled as a normal truncated below at the reporting threshold
    y = pm.TruncatedNormal(
        "y",
        mu=mu,
        sigma=sigma,
        lower=np.log(min_reported_loss),
        observed=np.log(observed_losses),
    )

with model:
    trace = pm.sample()

az.summary(trace)
```
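For reference, the `TruncatedNormal` likelihood scores each log-loss y with the normal log-density renormalised by the probability mass above the lower bound a: log φ((y − μ)/σ) − log σ − log(1 − Φ((a − μ)/σ)). Because it conditions on a claim being reported, it never sees how many claims fell below the threshold. A minimal stand-alone numpy version (`truncnorm_logp` is a hypothetical helper, not a PyMC function), with a check that the density integrates to one above the threshold:

```python
import math
import numpy as np

def truncnorm_logp(y, mu, sigma, lower):
    """Log-density of Normal(mu, sigma) truncated to [lower, inf), evaluated at y."""
    y = np.asarray(y, dtype=float)
    z = (y - mu) / sigma
    # Log of the untruncated normal density
    log_pdf = -0.5 * z**2 - math.log(sigma) - 0.5 * math.log(2 * math.pi)
    # Probability mass above the lower bound: 1 - Phi((lower - mu) / sigma)
    tail = 0.5 * math.erfc((lower - mu) / (sigma * math.sqrt(2)))
    logp = log_pdf - math.log(tail)
    # Values below the truncation point have zero density
    return np.where(y >= lower, logp, -np.inf)

# Numerical check with the simulation's true parameters (mu=10, sigma=3)
lower = math.log(1_000_000)
grid = np.linspace(lower, lower + 30, 200_001)
dens = np.exp(truncnorm_logp(grid, mu=10.0, sigma=3.0, lower=lower))
# Trapezoidal rule over the grid (10 standard deviations above the threshold)
integral = np.sum((dens[1:] + dens[:-1]) / 2 * np.diff(grid))
print(f"Density integrates to {integral:.6f} above the threshold")
```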

While writing this post, I realised it made a great prompt, which produced the following solution. It does seem to help a lot, especially when the number of draws is greatly reduced:

```python
with pm.Model() as model:
    mu = pm.Normal("mu", 10, 10)
    sigma = pm.HalfNormal("sigma", 5)

    y_obs = pm.TruncatedNormal(
        "y_obs",
        mu=mu,
        sigma=sigma,
        lower=np.log(min_reported_loss),
        observed=np.log(observed_losses),
    )

    # Adding a custom likelihood for the total sum of claims
    # NOTE: np.log(draws) contains every loss, including the unreported ones,
    # and subtracting total_loss only shifts the log-probability by a constant,
    # so it has no effect on the posterior.
    total_loss_constraint = pm.Potential(
        "total_loss_constraint",
        pm.logp(pm.Normal.dist(mu=mu, sigma=sigma), np.log(draws)).sum() - total_loss,
    )
```

This sneakily uses `np.log(draws)`, i.e. every individual loss including the unreported ones (and therefore their number), so it isn't actually a good solution.
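One possible way to use the total without touching `draws`, sketched under explicit assumptions and not tested inside PyMC: the unreported sum S = total_loss − sum(observed_losses) is a sum of M claims, each lognormal and truncated above at c = min_reported_loss. Assume (i) that M, given the n reported claims, is negative binomial with "success" probability q = P(X > c) (roughly what a vague prior on the total count implies), and (ii) that S is approximately normal by the central limit theorem. Then S has mean E[M]·E[X | X < c] and variance E[M]·Var(X | X < c) + Var(M)·E[X | X < c]², using the lognormal partial moments E[X^k · 1{X < c}] = exp(kμ + k²σ²/2)·Φ((ln c − μ − kσ²)/σ). The hypothetical helper `unreported_sum_loglik` below computes that approximate log-likelihood in plain numpy; in a model it would be added with `pm.Potential` after rewriting the math with tensor operations. It is not numerically hardened (no log-space computation).

```python
import math
import numpy as np

def norm_cdf(x):
    """Standard normal CDF via the error function."""
    return 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))

def unreported_sum_loglik(mu, sigma, total_loss, observed_losses, lower):
    """Approximate log-likelihood of the unreported sum (hypothetical helper).

    Assumes log-losses ~ Normal(mu, sigma), a negative-binomial count of
    unreported claims, and a normal (CLT) approximation to their sum.
    """
    n = len(observed_losses)
    log_c = math.log(lower)
    p_below = norm_cdf((log_c - mu) / sigma)  # P(X < c): claim goes unreported
    q = 1.0 - p_below                         # P(X > c): claim is reported

    def cond_moment(k):
        # E[X^k | X < c] from the lognormal partial moment
        partial = math.exp(k * mu + 0.5 * k**2 * sigma**2) * norm_cdf(
            (log_c - mu - k * sigma**2) / sigma
        )
        return partial / p_below

    m1 = cond_moment(1)
    var_x = cond_moment(2) - m1**2

    # Negative-binomial moments of the number of unreported claims
    mean_m = n * p_below / q
    var_m = n * p_below / q**2

    # Compound-sum moments, then a normal log-density for the observed unreported sum
    mean_s = mean_m * m1
    var_s = mean_m * var_x + var_m * m1**2
    s_obs = total_loss - np.sum(observed_losses)
    return -0.5 * (s_obs - mean_s) ** 2 / var_s - 0.5 * math.log(2 * math.pi * var_s)

# Usage on the simulated data: the true parameters should score best
rng = np.random.default_rng(20240802)
draws = rng.lognormal(10, 3, size=100_000)
total_loss = draws.sum()
min_reported_loss = 1_000_000
observed_losses = draws[draws > min_reported_loss]

for mu, sigma in [(10.0, 3.0), (8.0, 3.0), (12.0, 3.0)]:
    ll = unreported_sum_loglik(mu, sigma, total_loss, observed_losses, min_reported_loss)
    print(f"mu={mu}, sigma={sigma}: {ll:.1f}")
```

Only `total_loss`, `observed_losses` and their count enter the calculation, so it avoids the leak in the `pm.Potential` above; whether the CLT and negative-binomial approximations are adequate for smaller datasets would need checking.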