Help with fitting mixture -- data is fit much worse than fitting the distributions separately

Hello,

I am having an issue where some of my data does not fit how I expected. I went back to simulate data from scratch, and the same thing happens there. I have constructed the following examples and figures to illustrate the problem. To state the issue:

  • I can fit component 1 on its own
  • I can fit component 2 on its own
  • fitting the mixture is horrible

The code below produces figures for debugging too since I cant seem to attach figures (let me know if there is a way to do that)

Here is the simulation:

  • I have data belonging to two classes
  • Noise follows a Binomial (n=N, p=p0)
  • Signal follows a Binomial (n=N, p=p1*fraction)
  • Class 0 only draws from the noise binomial – its fraction is set to 0 for mixture fitting
  • Class 1 draws from both distributions and we observe the sum.

In theory, this to me means that I should be able to either:

  • Fit class 0 only, fit class 1 only, and then draw from the posterior to see how well we fit
  • Fit them as a mixture, but instead model signal as Binomial(n=N,p=p1) and make the mixing weights (fraction, 1-fraction) since this should carry through to the binomial via the model counts = 1-frac * Bin(N,p0) + frac * Bin(N,p1)

Here is the full code to reproduce the issue:
Simulating data:

import arviz as az
import pandas as pd
import scipy.stats as ss
import numpy as np

import pymc as pm
import pytensor.tensor as pt

background_rate = 1e-4
signal_rate=5e-2

simulated_data = pd.DataFrame(
    data = [ss.norm.rvs(loc = 10_000, scale = 1_500, size = 500)],
    index = ["N"]
).T

simulated_data["N"] = simulated_data["N"].astype(int)
simulated_data["sample_type"] = [np.random.randint(2) for i in range(len(simulated_data))]

simulated_data["fraction"] = ss.uniform.rvs(size=500)

simulated_data["background_counts"] = ss.binom.rvs(n=simulated_data["N"],p=background_rate)

simulated_data["signal_counts"] = ss.binom.rvs(n=simulated_data["N"],p=simulated_data["fraction"]*signal_rate)
cnd = simulated_data["sample_type"] == 0
simulated_data.loc[cnd,"signal_counts"] = 0

simulated_data["observed_counts"] = simulated_data["signal_counts"] + simulated_data["background_counts"]

for col in simulated_data:
    plt.figure(figsize=(7,5))
    sns.histplot(simulated_data[col])

Helper function

def save_model_params(idata):
    params = {}
    for param_name in list(idata.posterior.data_vars.keys()):
        param_value = idata.posterior.data_vars[param_name].mean().values
        params[param_name]=param_value

    return params

Separately fitting models:

%%capture --no-display
scores = simulated_data[simulated_data["sample_type"]==0].reset_index(drop=True)
coords = {"observation": scores.index.values}
for mode in ["fit","predict"]:
    model = pm.Model(coords=coords)
    with model:

        n = pm.ConstantData("n",value=scores["N"],dims="observation")
        frac = pm.ConstantData("frac",value=scores["fraction"],dims="observation")

        if mode == "fit":
            p = pm.Uniform("p")
            observed = scores["observed_counts"]
            dims = "observation"

        else:
            p = params["p"]
            observed = None
            dims = None

        counts = pm.Binomial('counts',
                       n = n,
                       p = p, 
                       observed=observed,
                       dims=dims
                    )

        with model:
            idata = pm.sample()
            pm.sample_posterior_predictive(idata, extend_inferencedata=True)

        params = save_model_params(idata)
        display(params)

    if "fit" in mode:
        display(az.plot_ppc(idata))
    else:
        predicted_counts = list(idata.posterior.data_vars["counts"].mean(axis=(0,1)).astype(int).values+1)
        observed_counts = list(scores["observed_counts"])
        tf = list(scores["fraction"])

        df = scores.copy()
        df["observed"] = observed_counts
        df["predicted"] = predicted_counts
        df["sample_type"] = df["sample_type"].astype(str)
        df["obs-pred_sign"] = (((df["observed"] - df["predicted"]))>0).apply(lambda x: 1 if x else -1)
        df["obs-pred_value"] = ((df["observed"] - df["predicted"]).abs()+1).apply(np.log10)
        df["obs-pred"] = df["obs-pred_sign"] * df["obs-pred_value"]

        fig = px.scatter(
            data_frame = df,
            x = "observed",
            y = "predicted",
            color = "sample_type",
            width = 750,
            height = 750,
            log_x=True,
            log_y=True,
        )

        fig.add_shape(type="line",line=dict(dash="dash",color="black"),x0=1,y0=1,x1=10_000,y1=10_000)
        display(fig)


        fig = px.scatter(
            data_frame = df,
            x = "observed",
            y = "fraction",
            color = "sample_type",
            width = 750,
            height = 750,
            log_x=True,
            log_y=True,
        )

        display(fig)

        plt.figure(figsize=(7,5))
        display(sns.distplot(df["obs-pred"]))

        display(df["obs-pred"].abs().describe(percentiles=[i/100 for i in range(0,100,10)]))
%%capture --no-display
scores = simulated_data[simulated_data["sample_type"]==1].reset_index(drop=True)
coords = {"observation": scores.index.values}
for mode in ["fit","predict"]:
    model = pm.Model(coords=coords)
    with model:

        n = pm.ConstantData("n",value=scores["N"],dims="observation")
        frac = pm.ConstantData("frac",value=scores["fraction"],dims="observation")

        if mode == "fit":
            p = pm.Uniform("p")
            observed = scores["observed_counts"]
            dims = "observation"

        else:
            p = params["p"]
            observed = None
            dims = None

        counts = pm.Binomial('counts',
                       n = n,
                       p = p*frac, 
                       observed=observed,
                       dims=dims
                    )

        with model:
            idata = pm.sample()
            pm.sample_posterior_predictive(idata, extend_inferencedata=True)

        params = save_model_params(idata)
        display(params)

    if "fit" in mode:
        display(az.plot_ppc(idata))
    else:
        predicted_counts = list(idata.posterior.data_vars["counts"].mean(axis=(0,1)).astype(int).values+1)
        observed_counts = list(scores["observed_counts"])
        tf = list(scores["fraction"])

        df = scores.copy()
        df["observed"] = observed_counts
        df["predicted"] = predicted_counts
        df["sample_type"] = df["sample_type"].astype(str)
        df["obs-pred_sign"] = (((df["observed"] - df["predicted"]))>0).apply(lambda x: 1 if x else -1)
        df["obs-pred_value"] = ((df["observed"] - df["predicted"]).abs()+1).apply(np.log10)
        df["obs-pred"] = df["obs-pred_sign"] * df["obs-pred_value"]

        fig = px.scatter(
            data_frame = df,
            x = "observed",
            y = "predicted",
            color = "sample_type",
            width = 750,
            height = 750,
            log_x=True,
            log_y=True,
        )

        fig.add_shape(type="line",line=dict(dash="dash",color="black"),x0=1,y0=1,x1=10_000,y1=10_000)
        display(fig)


        fig = px.scatter(
            data_frame = df,
            x = "observed",
            y = "fraction",
            color = "sample_type",
            width = 750,
            height = 750,
            log_x=True,
            log_y=True,
        )

        display(fig)

        plt.figure(figsize=(7,5))
        display(sns.distplot(df["obs-pred"]))

        display(df["obs-pred"].abs().describe(percentiles=[i/100 for i in range(0,100,10)]))

Fitting as a mixture

%%capture --no-display
scores = simulated_data[simulated_data["sample_type"]==1].reset_index(drop=True)
cnd = scores["sample_type"] == 0
scores.loc[cnd,"fraction"] = 0
coords = {"observation": scores.index.values}
for mode in ["fit","predict"]:
    model = pm.Model(coords=coords)
    with model:

        n = pm.ConstantData("n",value=scores["N"],dims="observation")
        frac = pm.ConstantData("frac",value=scores["fraction"],dims="observation")

        if mode == "fit":
            p_type0 = pm.Uniform("p_type0")
            p_type1 = pm.Uniform("p_type1")
            observed = scores["observed_counts"]
            dims = "observation"

        else:
            p_type0 = params["p_type0"]
            p_type1 = params["p_type1"]
            observed = None
            dims = None

        component_t0 = pm.Binomial.dist(
                       n = n,
                       p = p_type0, 
                    )
        
        component_t1 = pm.Binomial.dist(
                       n = n,
                       p = p_type1, 
                    )
        
        observed_counts = pm.Mixture(
            "counts",
            w = pt.stack([frac,1-frac],axis=1),
            comp_dists = [
                component_t1,
                component_t0
            ],
            observed = observed,
            dims = dims,
        )

        with model:
            idata = pm.sample()
            pm.sample_posterior_predictive(idata, extend_inferencedata=True)

        params = save_model_params(idata)
        display(params)

    if "fit" in mode:
        display(az.plot_ppc(idata))
    else:
        predicted_counts = list(idata.posterior.data_vars["counts"].mean(axis=(0,1)).astype(int).values+1)
        observed_counts = list(scores["observed_counts"])
        tf = list(scores["fraction"])

        df = scores.copy()
        df["observed"] = observed_counts
        df["predicted"] = predicted_counts
        df["sample_type"] = df["sample_type"].astype(str)
        df["obs-pred_sign"] = (((df["observed"] - df["predicted"]))>0).apply(lambda x: 1 if x else -1)
        df["obs-pred_value"] = ((df["observed"] - df["predicted"]).abs()+1).apply(np.log10)
        df["obs-pred"] = df["obs-pred_sign"] * df["obs-pred_value"]

        fig = px.scatter(
            data_frame = df,
            x = "observed",
            y = "predicted",
            color = "sample_type",
            width = 750,
            height = 750,
            log_x=True,
            log_y=True,
        )

        fig.add_shape(type="line",line=dict(dash="dash",color="black"),x0=1,y0=1,x1=10_000,y1=10_000)
        display(fig)


        fig = px.scatter(
            data_frame = df,
            x = "observed",
            y = "fraction",
            color = "sample_type",
            width = 750,
            height = 750,
            log_x=True,
            log_y=True,
        )

        display(fig)

        plt.figure(figsize=(7,5))
        display(sns.distplot(df["obs-pred"]))

        display(df["obs-pred"].abs().describe(percentiles=[i/100 for i in range(0,100,10)]))