After "Expand models supported by automatic marginalization" (pymc-devs/pymc-experimental#300) gets merged, this should work:
import numpy as np
import pymc as pm
from pymc.distributions.transforms import ordered

from pymc_experimental import MarginalModel
data = np.array([[-1., -1.], [0., 0.], [1., 1.], [2., 2.], [-3., -3.]] * 2).T
nobs = data.shape[-1]
n_clusters = 5
coords = {
    "cluster": range(n_clusters),
    "ndim": ("x", "y"),
    "obs": range(nobs),
}
with MarginalModel(coords=coords) as m:
    idx = pm.Categorical("idx", p=np.ones(n_clusters) / n_clusters, dims=["obs"])
    mu_x = pm.Normal("mu_x", dims=["cluster"], transform=ordered, initval=np.linspace(-1, 1, n_clusters))
    mu_y = pm.Normal("mu_y", dims=["cluster"])
    mu = pm.math.concatenate([mu_x[None], mu_y[None]], axis=0)
    sigma = pm.HalfNormal("sigma")
    y = pm.Censored(
        "y",
        dist=pm.Normal.dist(mu[:, idx], sigma),
        lower=-3,
        upper=3,
        observed=data,
        dims=["ndim", "obs"],
    )

    # Marginalize away the idx, creating a Mixture likelihood
    m.marginalize(idx)

    idata = pm.sample()
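For the record, what marginalizing out `idx` does here: each observation gets the standard mixture likelihood with uniform weights 1/K, and both coordinates of an observation share the same component because they index the same `idx`:

$$p(y_{\cdot,i} \mid \mu, \sigma) = \sum_{k=1}^{K} \frac{1}{K} \prod_{d \in \{x,\, y\}} p_{\text{CensoredNormal}}(y_{d,i} \mid \mu_{d,k}, \sigma), \qquad K = 5$$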
No need for the dummy CustomDist. I added this model as one of the new tests.
I also confirmed it's using NUTS on my machine:
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [mu_x, mu_y, sigma]
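As a quick post-sampling sanity check (plain ArviZ usage, nothing specific to this PR), note that `idx` no longer appears among the sampled variables:

import arviz as az

# idx is absent from the trace: it was marginalized out before sampling,
# so only the continuous parameters remain.
az.summary(idata, var_names=["mu_x", "mu_y", "sigma"])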