Is my model set up properly? Dependent Dirichlet process (DDP)

Hi everyone,

I am currently implementing a Dependent Dirichlet Process (DDP) mixture model in PyMC, and I would really appreciate any feedback or guidance from the community. The data consist of 12 features (standardized consumption variables) used for clustering and 5 covariate features (standardized sociodemographic variables).

I am following a dependent Dirichlet process model (an extension of the Dirichlet process mixture model that includes covariates) that I found in this paper (see Section 3.4.2, Single-θ dependent Dirichlet processes), and I am trying to follow the equations as closely as possible.
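As I understand it (written here in my own notation, so it may not match the paper exactly), the construction is

$$w_h(x) = \exp\!\left(-\lVert x - \phi_h \rVert^2 / \tau\right), \qquad V_h \sim \mathrm{Beta}(1, 1),$$

$$\pi_h(x) = w_h(x)\, V_h \prod_{j=1}^{h-1} \bigl(1 - w_j(x)\, V_j \bigr), \qquad Y_i \mid x_i \sim \sum_{h=1}^{K} \pi_h(x_i)\, \mathcal{N}\!\left(\mu_h, \sigma_h^2\right).$$

The overall model setup is below (generated with the help of AI tools):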

import pymc as pm
import arviz as az

# -------------------------
# Data & dimensions
# -------------------------
# X: n_samples × n_features (standardized consumption data)
# X_covariates_scaled: n_samples × n_covariates (standardized sociodemographic variables)
n_samples = X.shape[0]
n_features = X.shape[1]  # Number of features in consumption data (Y_i in Equation 9)
n_covariates = X_covariates_scaled.shape[1]  # Covariates for kernel smoothing (x in π_h(x))
K = 4  # Truncated number of mixture components (clusters)

with pm.Model() as ddp_model:

    # ---------------------------------------------
    # φ: cluster-specific location in covariate space
    # Used to define kernel function w_h(x)
    # Corresponds to φ_h in kernel function w_h(x) in V_h(x) = w_h(x) * V_h
    # ---------------------------------------------
    phi = pm.Normal("phi", mu=0, sigma=1, shape=(K, n_covariates))

    # ---------------------------------------------
    # Kernel weights: w_h(x) = exp(-||x - φ_h||² / τ)
    # This is the smoothing kernel that induces covariate dependence in V_h(x)
    # Corresponds to w_h(x) in the stick-breaking weights π_h(x)
    # ---------------------------------------------
    sq_dist = pm.math.sqr(X_covariates_scaled[:, None, :] - phi[None, :, :]).sum(axis=2)
    tau = pm.HalfNormal("tau", sigma=1.0)
    psi = pm.math.clip(pm.math.exp(-sq_dist / tau), 1e-8, 1.0)  # Ensures values in (0,1)

    # ---------------------------------------------
    # V_h ~ Beta(1,1): Base stick-breaking weights
    # These are the global (shared) stick-breaking proportions
    # Corresponds to V_h in π_h(x) = w_h(x) * V_h
    # ---------------------------------------------
    V = pm.Beta("V", alpha=1, beta=1, shape=K)

    # ---------------------------------------------
    # Stick-breaking process adapted to covariates:
    # π_1(x) = w_1(x) * V_1
    # π_h(x) = w_h(x) * V_h * ∏_{j=1}^{h-1}(1 - w_j(x) * V_j)
    # This gives π_h(x) that depends on covariates via kernel-smoothed weights
    # Corresponds directly to π_h(x) in Equation (9)
    # ---------------------------------------------
    remaining_stick = pm.math.concatenate([
        pm.math.ones((n_samples, 1)),
        pm.math.cumprod(1 - psi[:, :-1] * V[:-1], axis=1)
    ], axis=1)

    pi_raw = psi * V[None, :] * remaining_stick
    pi = pm.Deterministic("pi", pi_raw / (pi_raw.sum(axis=1, keepdims=True) + 1e-8))  # Normalized π_h(x)

    # ---------------------------------------------
    # Cluster-specific means and variances:
    # θ*_h = (μ_h, σ_h) ~ F₀ (e.g., Normal-Inverse-Gamma prior)
    # These are the "atoms" in the DDP mixture
    # Corresponds to θ*_h in Equation (9)
    # ---------------------------------------------
    mu = pm.Normal("mu", mu=0, sigma=5, shape=(K, n_features))
    sigma = pm.HalfNormal("sigma", sigma=2.0, shape=(K, n_features))

    # ---------------------------------------------
    # Mixture likelihood for each feature dimension
    # Y_i ∼ Σ_h π_h(x_i) δ_{θ*_h}
    # This defines the mixture model for each observed feature
    # ---------------------------------------------
    likelihoods = []
    for d in range(n_features):
        components_d = pm.Normal.dist(mu=mu[:, d], sigma=sigma[:, d], shape=K)
        likelihoods.append(
            pm.Mixture(
                name=f"likelihood_{d}",
                w=pi,
                comp_dists=components_d,
                observed=X[:, d],
            )
        )
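
Before sampling, I do a quick prior predictive check (a minimal sketch on my side, only to verify shapes and that each row of pi sums to roughly one):

with ddp_model:
    # Forward-sample only the mixture weights pi from the prior
    prior = pm.sample_prior_predictive(var_names=["pi"], random_seed=1234)

# pi has dimensions (chain, draw, n_samples, K); each row should sum to ~1
# (slightly below 1 here because of the 1e-8 added to the denominator)
pi_prior = prior.prior["pi"].values
print(pi_prior.shape)
print(pi_prior.sum(axis=-1).min(), pi_prior.sum(axis=-1).max())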


I am currently running the above model on Google Colab because my laptop has no GPU. The nuts_sampler is set to BlackJAX to make use of the GPU:

with ddp_model:  # Reuse the same model

    # Sampling with GPU support via BlackJAX for trace1
    trace1 = pm.sample(
        tune=1000,
        draws=2000,
        chains=1,
        init='advi+adapt_diag',  # Try this initialization
        random_seed=1234,  # Different seed
        compute_convergence_checks=True,
        target_accept=0.9,
        nuts_sampler="blackjax",
        var_names=["mu", "sigma", "pi", "phi"] # Trace only specific parameters
    )

    # Save trace1 to disk
    az.to_netcdf(data=trace1, filename=trace_path+"trace1.nc")
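
As a sanity check on the hardware, I also confirm that JAX can actually see the Colab GPU before sampling (otherwise everything just runs on the CPU):

import jax

# Should list a GPU device when the Colab GPU runtime is active
print(jax.devices())
print(jax.default_backend())  # 'gpu' if the GPU backend is in use, otherwise 'cpu'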

I can only run a single chain at a time, so I rerun the model with a different seed using the code below:


with ddp_model:  # Reuse the same model
    trace2 = pm.sample(
        tune=1000,
        draws=2000,
        chains=1,
        init='advi+adapt_diag',
        n_init=10000,  # Optional: longer ADVI init
        random_seed=5678,  # New seed!
        compute_convergence_checks=True,
        target_accept=0.9,
        nuts_sampler="blackjax",
        var_names=["mu", "sigma", "pi", "phi"]
    )

    # Save trace2
    az.to_netcdf(data=trace2, filename=trace_path + "trace2.nc")
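
To inspect convergence across the runs, I combine the single-chain traces along the chain dimension with ArviZ (a sketch; adjust var_names as needed):

# Load the saved single-chain runs and stack them as separate chains
trace1 = az.from_netcdf(trace_path + "trace1.nc")
trace2 = az.from_netcdf(trace_path + "trace2.nc")
combined = az.concat([trace1, trace2], dim="chain")

# Convergence diagnostics on the combined chains
print(az.rhat(combined, var_names=["mu", "sigma"]))
print(az.ess(combined, var_names=["mu", "sigma"]))
az.plot_trace(combined, var_names=["mu"])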

The results show convergence issues and poor mixing in the trace plots. Here is an example plot for cluster 0 across 3 chains for all 12 features:

The trace plots also show autocorrelation issues, and I think the chains were not really exploring; they seem to be stuck or barely moving at all when I put them together:

So my question is:

Could this be due to something structural in the model setup or the prior distributions, or is it more likely an inference-backend or tuning issue with BlackJAX, i.e. something I have set up incorrectly?

I can’t quite follow all your code as I have no idea what clip does, but what you are seeing with your phi coefficients is that they’re overparameterized. A simplex with N components only has N - 1 degrees of freedom (and hence N - 1 dimensions in the topological sense) because the last value has to be equal to one minus the sum of the previous values. In my experience, most of the problems you get with multinomial sampling (including those nested in Dirichlet) are identifiability problems of this kind.

You see the lack of identifiability in your construction of pi using a deterministic many-to-one function. If you instead pin one of the dimensions to zero following the classical approach, then everything is identified and you will find sampling will go a lot faster. On the other hand, it makes the prior asymmetric with one of the values pinned.
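
A minimal sketch of what I mean (hypothetical names, and a plain softmax over logits rather than your kernel-weighted stick-breaking; the first logit is fixed at zero so there are only K - 1 free parameters):

import pymc as pm
import pytensor.tensor as pt

K = 4
with pm.Model() as demo_model:
    # K - 1 free logits; the first component's logit is pinned to zero
    eta_free = pm.Normal("eta_free", mu=0.0, sigma=1.0, shape=K - 1)
    eta = pt.concatenate([pt.zeros(1), eta_free])
    # Softmax yields a proper simplex that sums to exactly one
    w = pm.Deterministic("w", pm.math.softmax(eta, axis=-1))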

Your pi is being constructed to actually violate the sum-to-one constraint with the addition of the 1e-8. I guess PyMC isn’t going to gripe at that. To get this to be a proper simplex, you need to add the 1e-8 to pi_raw in both the numerator and the denominator; then it’d be equivalent to having a super-weak Dirichlet prior for smoothing.
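
In code, that would look something like this (replacing your pi line inside the model block; eps is just the same small constant):

eps = 1e-8
# Add eps to the numerator as well, so each row of pi sums to exactly one
pi = pm.Deterministic("pi", (pi_raw + eps) / (pi_raw + eps).sum(axis=1, keepdims=True))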

I have a system prompt installed to try to stop ChatGPT from generating doc on every line, which is a terrible practice for code readability and maintainability and goes against every doc guide I have ever seen.