Is my model set up properly? Dependent Dirichlet process (DDP)

Hi everyone,

I am currently implementing a Dependent Dirichlet Process (DDP) mixture model in PyMC, and I would really appreciate any feedback or guidance from the community. The data consist of 12 consumption features (used for clustering) and 5 covariate features.
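For context, both blocks of data are z-scored before entering the model. A minimal sketch of that preprocessing (the DataFrame names df_consumption and df_covariates are placeholders for my actual data):

from sklearn.preprocessing import StandardScaler

# Placeholder DataFrames:
# df_consumption: n_samples × 12 consumption features
# df_covariates:  n_samples × 5 sociodemographic covariates
X = StandardScaler().fit_transform(df_consumption.values)
X_covariates_scaled = StandardScaler().fit_transform(df_covariates.values)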

I followed a dependent Dirichlet process model (an extension of the Dirichlet process mixture model that includes covariates) that I found in this paper (see section 3.4.2, Single-θ dependent Dirichlet processes). I am trying to follow the equations as closely as possible; my reading of the weight construction is sketched just below, followed by the overall model setup (generated with AI tools):
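This is my understanding of the covariate-dependent stick-breaking construction I am trying to implement (written in my own notation, which may differ slightly from the paper's Equation 9):

$$
w_h(x) = \exp\!\left(-\frac{\lVert x - \varphi_h \rVert^2}{\tau}\right), \qquad V_h \sim \mathrm{Beta}(1, 1),
$$

$$
\pi_h(x) = w_h(x)\, V_h \prod_{j<h} \bigl(1 - w_j(x)\, V_j\bigr), \qquad
Y_i \mid x_i \sim \sum_{h=1}^{K} \pi_h(x_i)\, \mathcal{N}\bigl(\mu_h,\ \sigma_h^2\bigr).
$$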

import pymc as pm

# -------------------------
# Data & dimensions
# -------------------------
# X: n_samples × n_features (standardized consumption data)
# X_covariates_scaled: n_samples × n_covariates (standardized sociodemographic variables)
n_samples = X.shape[0]
n_features = X.shape[1]  # Number of features in consumption data (Y_i in Equation 9)
n_covariates = X_covariates_scaled.shape[1]  # Covariates for kernel smoothing (x in π_h(x))
K = 4  # Truncated number of mixture components (clusters)

with pm.Model() as ddp_model:

    # ---------------------------------------------
    # φ: cluster-specific location in covariate space
    # Used to define kernel function w_h(x)
    # Corresponds to φ_h in kernel function w_h(x) in V_h(x) = w_h(x) * V_h
    # ---------------------------------------------
    phi = pm.Normal("phi", mu=0, sigma=1, shape=(K, n_covariates))

    # ---------------------------------------------
    # Kernel weights: w_h(x) = exp(-||x - φ_h||² / τ)
    # This is the smoothing kernel that induces covariate dependence in V_h(x)
    # Corresponds to w_h(x) in the stick-breaking weights π_h(x)
    # ---------------------------------------------
    sq_dist = pm.math.sqr(X_covariates_scaled[:, None, :] - phi[None, :, :]).sum(axis=2)
    tau = pm.HalfNormal("tau", sigma=1.0)
    psi = pm.math.clip(pm.math.exp(-sq_dist / tau), 1e-8, 1.0)  # Ensures values in (0,1)

    # ---------------------------------------------
    # V_h ~ Beta(1,1): Base stick-breaking weights
    # These are the global (shared) stick-breaking proportions
    # Corresponds to V_h in π_h(x) = w_h(x) * V_h
    # ---------------------------------------------
    V = pm.Beta("V", alpha=1, beta=1, shape=K)

    # ---------------------------------------------
    # Stick-breaking process adapted to covariates:
    # π_1(x) = w_1(x) * V_1
    # π_h(x) = w_h(x) * V_h * ∏_{j=1}^{h-1}(1 - w_j(x) * V_j)
    # This gives π_h(x) that depends on covariates via kernel-smoothed weights
    # Corresponds directly to π_h(x) in Equation (9)
    # ---------------------------------------------
    remaining_stick = pm.math.concatenate([
        pm.math.ones((n_samples, 1)),
        pm.math.cumprod(1 - psi[:, :-1] * V[:-1], axis=1)
    ], axis=1)

    pi_raw = psi * V[None, :] * remaining_stick
    pi = pm.Deterministic("pi", pi_raw / (pi_raw.sum(axis=1, keepdims=True) + 1e-8))  # Normalized π_h(x)

    # ---------------------------------------------
    # Cluster-specific means and variances:
    # θ*_h = (μ_h, σ_h) ~ F₀ (e.g., Normal-Inverse-Gamma prior)
    # These are the "atoms" in the DDP mixture
    # Corresponds to θ*_h in Equation (9)
    # ---------------------------------------------
    mu = pm.Normal("mu", mu=0, sigma=5, shape=(K, n_features))
    sigma = pm.HalfNormal("sigma", sigma=2.0, shape=(K, n_features))

    # ---------------------------------------------
    # Mixture likelihood for each feature dimension
    # Y_i ∼ Σ_h π_h(x_i) δ_{θ*_h}
    # This defines the mixture model for each observed feature
    # ---------------------------------------------
    likelihoods = []
    for d in range(n_features):
        components_d = pm.Normal.dist(mu=mu[:, d], sigma=sigma[:, d], shape=K)
        likelihoods.append(
          pm.Mixture(
              name=f"likelihood_{d}",
              w=pi,
              comp_dists=components_d,
              observed=X[:, d]
          )
        )


I am running the above model on Google Colab because my laptop has no GPU. I selected BlackJAX as the nuts_sampler to make use of the GPU:

import arviz as az

with ddp_model:  # Reuse the same model

    # Sampling with GPU support via BlackJAX for trace1
    trace1 = pm.sample(
        tune=1000,
        draws=2000,
        chains=1,
        init='advi+adapt_diag',  # Try this initialization
        random_seed=1234,  # Different seed
        compute_convergence_checks=True,
        target_accept=0.9,
        nuts_sampler="blackjax",
        var_names=["mu", "sigma", "pi", "phi"] # Trace only specific parameters
    )

    # Save trace1 to disk
    az.to_netcdf(data=trace1, filename=trace_path+"trace1.nc")
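As a quick sanity check on the GPU side, something like the following shows whether JAX (and hence BlackJAX) actually sees the Colab GPU:

import jax

# If this lists only CPU devices, BlackJAX silently falls back to the CPU
# even with nuts_sampler="blackjax".
print(jax.devices())
print(jax.default_backend())  # "gpu" on a GPU runtime, "cpu" otherwise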

I can only run a single chain at a time, so I rerun the sampler with the code below using a different seed:


with ddp_model:  # Reuse the same model
    trace2 = pm.sample(
        tune=1000,
        draws=2000,
        chains=1,
        init='advi+adapt_diag',
        n_init=10000,  # Optional: longer ADVI init
        random_seed=5678,  # New seed!
        compute_convergence_checks=True,
        target_accept=0.9,
        nuts_sampler="blackjax",
        var_names=["mu", "sigma", "pi", "phi"]
    )

    # Save trace2
    az.to_netcdf(data=trace2, filename=trace_path + "trace2.nc")
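To get multi-chain diagnostics (R-hat, ESS, overlaid trace plots) out of these separate single-chain runs, I then stack the saved traces along the chain dimension, roughly like this (a sketch; it assumes both runs use the identical model and variable set):

import arviz as az

trace1 = az.from_netcdf(trace_path + "trace1.nc")
trace2 = az.from_netcdf(trace_path + "trace2.nc")

# Treat the two single-chain runs as chains 0 and 1 of one InferenceData
combined = az.concat(trace1, trace2, dim="chain")

print(az.summary(combined, var_names=["mu", "sigma"]))
az.plot_trace(combined, var_names=["mu"])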

The results show convergence issues and poor mixing in the trace plots. Here is an example plot for cluster 0 across 3 chains for all 12 features:

The trace plots also show autocorrelation issues, and I do not think the chains were really exploring; they seem stuck or barely moving when I overlay them.
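To quantify this, I look at the autocorrelation and effective sample size on the combined trace from above, e.g.:

# Rough diagnostics for the mixing problems seen in the trace plots
az.plot_autocorr(combined, var_names=["mu"])
print(az.ess(combined, var_names=["mu", "sigma"]))
print(az.rhat(combined, var_names=["mu", "sigma"]))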

So my question:

Could this be due to something structural in the model setup or to the choice of priors, or is it more likely an inference backend / tuning issue with BlackJAX, i.e. something I set up incorrectly?