# Simulate data from a two-component Gaussian mixture and fit a
# NormalMixture model to it. One seed drives both the simulation and the
# sampler so that the whole script is reproducible.
RANDOM_SEED = 8927
rng = np.random.default_rng(RANDOM_SEED)

N = 1000  # number of simulated observations
W = np.array([0.5, 0.5])  # mixture weights used for the simulation
μ_orig = np.array([0.5, 2.5])  # component means used for the simulation
σ_orig = np.array([0.45, 0.5])  # component standard deviations used for the simulation

# Assign each observation to a component (with probabilities W), then draw
# it from that component's normal distribution.
component = rng.choice(μ_orig.size, size=N, p=W)
data = rng.normal(μ_orig[component], σ_orig[component], size=N)

with pm.Model(coords={"cluster": np.arange(len(W)), "data_id": np.arange(N)}) as model_multi:
    # Flat Dirichlet prior (all concentration parameters equal to 1) on the
    # mixture weights.
    w = pm.Dirichlet("w", np.ones_like(W))

    # Component means. The `ordered` transform keeps the means sorted, which
    # avoids the two components swapping labels during sampling; the initial
    # value is supplied in increasing order to be compatible with it.
    mu = pm.Normal(
        "mu",
        np.zeros_like(W),
        1.0,
        dims="cluster",
        transform=pm.distributions.transforms.ordered,
        initval=[1, 2],
    )

    # Per-component precision, passed to the mixture via `tau=`.
    tau = pm.Gamma("tau", 1.0, 1.0, dims="cluster")

    data_obs = pm.NormalMixture("data_obs", w, mu, tau=tau, observed=data, dims="data_id")

    # Pass the seed to the sampler as well: seeding only `rng` fixes the
    # simulated data but leaves the posterior draws different on every run.
    trace = pm.sample(
        5000,
        n_init=10000,
        tune=1000,
        return_inferencedata=True,
        random_seed=RANDOM_SEED,
    )