Modelling a Mixture Model

Hi! I'm trying to model a Gaussian mixture model using the stick-breaking process. My model goes as follows:

    import pymc as pm

    # K (number of components), data (1-D array of observations) and
    # stick_breaking (maps the K Beta draws to mixture weights) are defined elsewhere
    with pm.Model() as st_model:
        alpha = pm.Gamma("alpha", 1.0, 1.0)
        beta = pm.Beta("beta", 1, alpha, shape=(K,))
        w = pm.Deterministic("ws", stick_breaking(beta))
        mu = pm.Gamma("mu_p", mu=0.1, sigma=0.1, shape=(K,))
        sigma = pm.Gamma("sigma_p", mu=0.1, sigma=0.1, shape=(K,))

        obs = pm.NormalMixture(
            "likelihood", w, mu, sigma=sigma, observed=data, shape=data.shape
        )
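
The post does not show `stick_breaking`. Assuming it implements the usual truncated construction w_k = beta_k * prod_{j<k} (1 - beta_j), the sketch below reproduces that formula in NumPy with a hypothetical helper `stick_breaking_np`, so the weights can be inspected outside PyMC. With only K sticks the weights sum to 1 - prod(1 - beta_k), which is slightly less than one; renormalising, or giving the leftover mass to an extra component as done here, are two common ways to close that gap.

```python
import numpy as np


def stick_breaking_np(beta: np.ndarray) -> np.ndarray:
    """Hypothetical NumPy helper: truncated stick-breaking weights.

    w_k = beta_k * prod_{j<k} (1 - beta_j)
    """
    # Stick length left before each break: 1, (1 - b0), (1 - b0)(1 - b1), ...
    remaining = np.concatenate([[1.0], np.cumprod(1.0 - beta)[:-1]])
    return beta * remaining


rng = np.random.default_rng(0)
K = 10
alpha = 1.0
beta = rng.beta(1.0, alpha, size=K)  # same prior form as Beta(1, alpha) in the model
w = stick_breaking_np(beta)

# The K weights only sum to 1 - prod(1 - beta), i.e. slightly less than one.
leftover = 1.0 - w.sum()
# One possible fix: assign the leftover mass to an extra (K+1)-th component.
w_full = np.append(w, leftover)
```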

The prior predictive check looks correct according to this plot (observed data in orange):
[Figure: prior_stick_breaking]
I made that plot with the following function, because PyMC raises an error when I try to sample the likelihood of a mixture model:

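    import matplotlib.pyplot as plt
    import numpy as np
    import seaborn as sns
    from scipy import stats

    # PLOTS is a pathlib.Path to the output directory, defined elsewhere
    def sample_mixture(observed, prior_data):
        # Grid on which the mixture CDF is evaluated
        xs = np.linspace(observed.min(), 1.1 * observed.max(), 2000)

        # Average each parameter over all prior draws
        mu_p = prior_data.prior["mu_p"].mean(("chain", "draw")).to_numpy()
        sigma_p = prior_data.prior["sigma_p"].mean(("chain", "draw")).to_numpy()
        ws = prior_data.prior["ws"].mean(("chain", "draw")).to_numpy()

        # Mixture CDF: weighted sum of the component CDFs
        cdfs = np.array([w * stats.norm(loc=mu, scale=s).cdf(xs)
                         for mu, s, w in zip(mu_p, sigma_p, ws)])
        cdfs = cdfs.sum(axis=0)

        # Inverse-CDF sampling: for each level p, take the largest x with CDF <= p
        # (clipped to xs[0] when even the first CDF value exceeds p)
        u_samps = np.linspace(0, 1, 2000)
        idx = np.clip(np.searchsorted(cdfs, u_samps, side="right") - 1, 0, None)
        sampled = xs[idx].astype(np.float32)

        fig, ax = plt.subplots()
        sns.histplot({"sampled_score": sampled}, x="sampled_score", ax=ax)
        sns.histplot({"observed_score": observed}, x="observed_score", ax=ax)
        plt.savefig(PLOTS / "score_prior.png")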
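
The function above plugs the average of each parameter over all prior draws into a single mixture, so the histogram shows one mixture at the mean parameters rather than the spread of the prior predictive distribution. A minimal alternative sketch, assuming the prior draws are available as NumPy arrays of shape (n_draws, K): for every prior draw, pick a component according to that draw's weights, then draw from the chosen normal (ancestral sampling). The helper `mixture_prior_predictive` and the demo arrays are hypothetical; weights are renormalised per draw in case the truncated stick-breaking weights do not sum exactly to one.

```python
import numpy as np


def mixture_prior_predictive(mu_p, sigma_p, ws, n_per_draw=1, seed=None):
    """Draw n_per_draw values from a normal mixture for each prior draw.

    mu_p, sigma_p, ws: arrays of shape (n_draws, K).
    """
    rng = np.random.default_rng(seed)
    mu_p, sigma_p, ws = map(np.asarray, (mu_p, sigma_p, ws))
    n_draws, K = ws.shape
    probs = ws / ws.sum(axis=1, keepdims=True)  # renormalise each draw's weights
    out = np.empty((n_draws, n_per_draw))
    for d in range(n_draws):
        # Pick a component for each sample according to this draw's weights
        comp = rng.choice(K, size=n_per_draw, p=probs[d])
        # Then draw from the chosen normal component
        out[d] = rng.normal(mu_p[d, comp], sigma_p[d, comp])
    return out.ravel()


# Hypothetical prior draws, only to exercise the function
rng = np.random.default_rng(1)
n_draws, K = 500, 3
mu_demo = rng.normal(0.0, 1.0, size=(n_draws, K))
sigma_demo = rng.gamma(2.0, 0.5, size=(n_draws, K))
ws_demo = rng.dirichlet(np.ones(K), size=n_draws)
samples = mixture_prior_predictive(mu_demo, sigma_demo, ws_demo, n_per_draw=4, seed=2)
```

With ArviZ output, arrays of the expected shape can likely be obtained with `prior_data.prior["mu_p"].to_numpy().reshape(-1, K)` (and likewise for `sigma_p` and `ws`), assuming the component dimension comes after `chain` and `draw`; `samples` can then go to `sns.histplot` in place of `sampled`.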

I'm having trouble sampling from this model because the initial evaluation fails. That suggests the prior is a poor fit to the data, but my plot seems to indicate the contrary.
What am I not seeing here?
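
An initial-evaluation failure in PyMC means the model's log-probability is not finite (-inf or NaN) at the starting point, so it helps to find which term is responsible. In recent PyMC versions, `st_model.point_logps()` should list the log-probability of each variable at the initial point. The likelihood term can also be checked by hand; below is a minimal SciPy sketch with a hypothetical helper `mixture_loglik` and hypothetical parameter values and observations (the real `data` would be substituted for `x_demo`).

```python
import numpy as np
from scipy import stats
from scipy.special import logsumexp


def mixture_loglik(x, w, mu, sigma):
    """Per-observation log density of a normal mixture.

    log p(x_i) = logsumexp_k(log w_k + log N(x_i | mu_k, sigma_k))
    Weights are used as given, without renormalisation.
    """
    x = np.asarray(x, dtype=float)[:, None]  # (n, 1), broadcasts against K components
    with np.errstate(divide="ignore"):
        log_w = np.log(w)  # a zero weight gives -inf, which logsumexp handles
    comp = stats.norm.logpdf(x, loc=mu, scale=sigma)  # (n, K)
    return logsumexp(log_w + comp, axis=1)


# Hypothetical point built from prior means in the post's model:
# mu_p and sigma_p have mean 0.1; the Gamma(1, 1) prior on alpha has mean 1,
# and Beta(1, 1) has mean 0.5.
K = 3
w0 = 0.5 * 0.5 ** np.arange(K)  # stick-breaking weights for beta_k = 0.5
mu0 = np.full(K, 0.1)
sigma0 = np.full(K, 0.1)

x_demo = np.array([0.05, 0.2, 5.0])  # hypothetical observations
ll = mixture_loglik(x_demo, w0, mu0, sigma0)
bad = np.flatnonzero(~np.isfinite(ll))  # indices where the log-likelihood is -inf or NaN
```

Very negative but finite values (like the one for the observation at 5.0 here) do not by themselves stop the sampler, but they show which observations the starting point explains worst.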

Can you provide a runnable version of your code that produces the error you are seeing?