Mixture of Censored iid Normals

After pymc-devs/pymc-experimental#300 ("Expand models supported by automatic marginalization", by ricardoV94) gets merged, this should work:

import numpy as np
import pymc as pm
from pymc_experimental import MarginalModel
from pymc.distributions.transforms import ordered

data = np.array([[-1., -1.], [0., 0.], [1., 1.], [2., 2.], [-3., -3.]] * 2).T
nobs = data.shape[-1]
n_clusters = 5
coords = {
    "cluster": range(n_clusters), 
    "ndim": ("x", "y"), 
    "obs": range(nobs),
}
with MarginalModel(coords=coords) as m:
    idx = pm.Categorical("idx", p=np.ones(n_clusters) / n_clusters, dims=["obs"])
    mu_x = pm.Normal("mu_x", dims=["cluster"], transform=ordered, initval=np.linspace(-1, 1, n_clusters))
    mu_y = pm.Normal("mu_y", dims=["cluster"])
    mu = pm.math.concatenate([mu_x[None], mu_y[None]], axis=0)
        
    sigma = pm.HalfNormal("sigma")
    y = pm.Censored(
        "y",
        dist=pm.Normal.dist(mu[:, idx], sigma),
        lower=-3, 
        upper=3,
        observed=data,
        dims=["ndim", "obs"],
    )

    # Marginalize away the idx, creating a Mixture likelihood
    m.marginalize(idx)
    idata = pm.sample()

No need for the dummy CustomDist. I added this model as one of the new tests.
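
For context, the marginalized likelihood this produces is just a logsumexp over the equally weighted clusters of the censored-normal logp, summed over both dims per observation. Below is a rough hand-written sketch of that same likelihood using a Potential (reusing data, coords, n_clusters, and ordered from the snippet above); the formulation and variable names are illustrative only, not what marginalize actually builds under the hood:

import pymc as pm
import pytensor.tensor as pt

with pm.Model(coords=coords) as manual:
    mu_x = pm.Normal("mu_x", dims=["cluster"], transform=ordered, initval=np.linspace(-1, 1, n_clusters))
    mu_y = pm.Normal("mu_y", dims=["cluster"])
    sigma = pm.HalfNormal("sigma")

    # Per-cluster means with shape (cluster, ndim, 1), broadcasting against the (ndim, obs) data
    mu_c = pt.stack([mu_x, mu_y], axis=1)[:, :, None]
    comp = pm.Censored.dist(pm.Normal.dist(mu_c, sigma), lower=-3, upper=3)

    # logp per (cluster, ndim, obs): sum over ndim, then logsumexp over the equal-weight clusters
    logp_comp = pm.logp(comp, data).sum(axis=1)
    logp_mix = pm.math.logsumexp(np.log(1 / n_clusters) + logp_comp, axis=0)
    pm.Potential("y_marginal", logp_mix.sum())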

I also confirmed it’s using NUTS on my machine:

Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [mu_x, mu_y, sigma]