Hi everyone,
I am currently implementing a Dependent Dirichlet Process (DDP) mixture model in PyMC, and I would really appreciate any feedback or guidance from the community. The data consist of 12 features (standardized consumption data) and 5 covariates (standardized sociodemographic variables).
I followed a Dependent Dirichlet Process model (an extension of the Dirichlet Process mixture model that includes covariates) described in this paper (see Section 3.4.2, "Single-θ dependent Dirichlet processes"), and I am trying to follow the equations as closely as possible. The overall model setup is below (code drafted with the help of AI tools).
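For reference, this is my paraphrase of the construction I am targeting (the kernel-weighted stick-breaking in Equation 9), written in the same notation as the code comments:

w_h(x) = exp(-||x - φ_h||² / τ)
π_1(x) = w_1(x) · V_1,   π_h(x) = w_h(x) · V_h · ∏_{j=1}^{h-1} (1 - w_j(x) · V_j),   V_h ~ Beta(1, 1)
Y_i | x_i ~ Σ_h π_h(x_i) · N(μ_h, σ_h²),   with atoms θ*_h = (μ_h, σ_h) drawn from a base measure F₀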
import arviz as az   # used later to save the traces
import pymc as pm

# -------------------------
# Data & dimensions
# -------------------------
# X: n_samples × n_features (standardized consumption data)
# X_covariates_scaled: n_samples × n_covariates (standardized sociodemographic variables)
n_samples = X.shape[0]
n_features = X.shape[1]                      # Number of features in consumption data (Y_i in Equation 9)
n_covariates = X_covariates_scaled.shape[1]  # Covariates for kernel smoothing (x in π_h(x))
K = 4                                        # Truncated number of mixture components (clusters)
with pm.Model() as ddp_model:
    # ---------------------------------------------
    # φ: cluster-specific location in covariate space
    # Used to define kernel function w_h(x)
    # Corresponds to φ_h in kernel function w_h(x) in V_h(x) = w_h(x) * V_h
    # ---------------------------------------------
    phi = pm.Normal("phi", mu=0, sigma=1, shape=(K, n_covariates))

    # ---------------------------------------------
    # Kernel weights: w_h(x) = exp(-||x - φ_h||² / τ)
    # This is the smoothing kernel that induces covariate dependence in V_h(x)
    # Corresponds to w_h(x) in the stick-breaking weights π_h(x)
    # ---------------------------------------------
    sq_dist = pm.math.sqr(X_covariates_scaled[:, None, :] - phi[None, :, :]).sum(axis=2)
    tau = pm.HalfNormal("tau", sigma=1.0)
    psi = pm.math.clip(pm.math.exp(-sq_dist / tau), 1e-8, 1.0)  # Ensures values in (0, 1)

    # ---------------------------------------------
    # V_h ~ Beta(1,1): Base stick-breaking weights
    # These are the global (shared) stick-breaking proportions
    # Corresponds to V_h in π_h(x) = w_h(x) * V_h
    # ---------------------------------------------
    V = pm.Beta("V", alpha=1, beta=1, shape=K)

    # ---------------------------------------------
    # Stick-breaking process adapted to covariates:
    # π_1(x) = w_1(x) * V_1
    # π_h(x) = w_h(x) * V_h * ∏_{j=1}^{h-1}(1 - w_j(x) * V_j)
    # This gives π_h(x) that depends on covariates via kernel-smoothed weights
    # Corresponds directly to π_h(x) in Equation (9)
    # ---------------------------------------------
    remaining_stick = pm.math.concatenate([
        pm.math.ones((n_samples, 1)),
        pm.math.cumprod(1 - psi[:, :-1] * V[:-1], axis=1)
    ], axis=1)
    pi_raw = psi * V[None, :] * remaining_stick
    pi = pm.Deterministic("pi", pi_raw / (pi_raw.sum(axis=1, keepdims=True) + 1e-8))  # Normalized π_h(x)

    # ---------------------------------------------
    # Cluster-specific means and variances:
    # θ*_h = (μ_h, σ_h) ~ F₀ (e.g., Normal-Inverse-Gamma prior)
    # These are the "atoms" in the DDP mixture
    # Corresponds to θ*_h in Equation (9)
    # ---------------------------------------------
    mu = pm.Normal("mu", mu=0, sigma=5, shape=(K, n_features))
    sigma = pm.HalfNormal("sigma", sigma=2.0, shape=(K, n_features))

    # ---------------------------------------------
    # Mixture likelihood for each feature dimension
    # Y_i ∼ Σ_h π_h(x_i) δ_{θ*_h}
    # This defines the mixture model for each observed feature
    # ---------------------------------------------
    likelihoods = []
    for d in range(n_features):
        components_d = pm.Normal.dist(mu=mu[:, d], sigma=sigma[:, d], shape=K)
        likelihoods.append(
            pm.Mixture(
                name=f"likelihood_{d}",
                w=pi,
                comp_dists=components_d,
                observed=X[:, d]
            )
        )
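Before sampling, I run a quick prior sanity check (my own sketch, not something from the paper) to confirm that the covariate-dependent weights have shape (n_samples, K) and that each row of the normalized π is a valid probability vector:

# Prior sanity check (my own sketch, not part of the paper's construction):
# draw "pi" from the prior and confirm each row sums to ~1 after normalization.
with ddp_model:
    prior = pm.sample_prior_predictive(100, var_names=["pi"], random_seed=1234)

pi_prior = prior.prior["pi"].values              # shape: (chain, draw, n_samples, K)
print(pi_prior.shape)
print(pi_prior.sum(axis=-1).min(), pi_prior.sum(axis=-1).max())  # both should be ≈ 1.0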
I am running the above model on Google Colab because my laptop has no GPU. The NUTS sampler selected is BlackJAX, to make use of the GPU:
with ddp_model:  # Reuse the same model
    # Sampling with GPU support via BlackJAX for trace1
    trace1 = pm.sample(
        tune=1000,
        draws=2000,
        chains=1,
        init='advi+adapt_diag',  # Try this initialization
        random_seed=1234,
        compute_convergence_checks=True,
        target_accept=0.9,
        nuts_sampler="blackjax",
        var_names=["mu", "sigma", "pi", "phi"]  # Trace only specific parameters
    )

# Save trace1 to disk
az.to_netcdf(data=trace1, filename=trace_path + "trace1.nc")
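As a side note, before sampling I also check that JAX (the backend BlackJAX runs on) actually sees the Colab GPU, since otherwise sampling falls back to the CPU; a minimal check:

# Quick environment check (not model-specific): make sure the JAX backend
# used by BlackJAX has the Colab GPU available.
import jax
print(jax.devices())           # should list a GPU device on a GPU runtime
print(jax.default_backend())   # "gpu" when the GPU is actually being used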
I can only run a single chain at a time, so I reran the sampler with the code below, using a different seed:
with ddp_model:  # Reuse the same model
    trace2 = pm.sample(
        tune=1000,
        draws=2000,
        chains=1,
        init='advi+adapt_diag',
        n_init=10000,  # Optional: longer ADVI init
        random_seed=5678,  # New seed!
        compute_convergence_checks=True,
        target_accept=0.9,
        nuts_sampler="blackjax",
        var_names=["mu", "sigma", "pi", "phi"]
    )

# Save trace2
az.to_netcdf(data=trace2, filename=trace_path + "trace2.nc")
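Because each run gives only one chain, I stack the saved runs along the chain dimension with ArviZ before looking at diagnostics, so that cross-chain checks like r-hat are meaningful (a sketch; trace_path as above):

# Combine the separate single-chain runs into one multi-chain InferenceData
# so that r-hat and cross-chain comparisons can be computed.
import arviz as az

runs = [az.from_netcdf(trace_path + f"trace{i}.nc") for i in (1, 2)]  # extend for any further runs
trace_all = az.concat(*runs, dim="chain")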
The results show convergence issues and poor mixing in the trace plots. Here is an example plot for cluster 0 across 3 chains for all 12 features:
The trace plots also show autocorrelation problems, and I don't think the chains are really exploring; they seem to be stuck, or barely moving at all, when plotted together:
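For reference, these are roughly the diagnostics I look at on the combined trace (a sketch):

# Diagnostics on the combined trace: r-hat / effective sample size for the
# cluster parameters, plus trace and autocorrelation plots for the means.
print(az.summary(trace_all, var_names=["mu", "sigma"])[["r_hat", "ess_bulk", "ess_tail"]])
az.plot_trace(trace_all, var_names=["mu"], compact=True)
az.plot_autocorr(trace_all, var_names=["mu"], combined=True)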
So my question:
Could this be due to something structural in the model setup or in the choice of priors, or is it more likely an inference-backend or tuning issue with BlackJAX, i.e., something I set up incorrectly on the sampling side?