Dirichlet process mixture model with 2d input

I am following this excellent introduction to DPMM, but I want to use the general scheme for a classification task and generalize it to run with multidimensional samples. Here is the code:

import pymc as pm
from pytensor import tensor as tt
import seaborn as sns
import arviz as az
import numpy as np

data_in = np.asarray([[316, 271, 307, 322, 298, 268, 245, 268, 347, 356],
       [142, 145, 161, 137, 150, 161, 188, 158, 139, 142]]) 

N = data_in.shape[0] #number of samples
J = data_in.shape[1] #number of features
K = N #number of sticks = maximum number of clusters ~= number of samples

# data_in shape is (J, K)

def stick_breaking(beta):
    portion_remaining = tt.concatenate([[1], tt.extra_ops.cumprod(1 - beta)[:-1]])
    return beta * portion_remaining

with pm.Model() as model:
    alpha = pm.Gamma("alpha", 1.0, 1.0) #prior shape parameter
    beta = pm.Beta("beta", 1.0, alpha, shape=(J,K)) #beta for stick-breaking process, takes alpha as hyperparameter
    w = pm.Deterministic("w", stick_breaking(beta)) #Dirichlet-distributed variable using stick-breaking

    tau = pm.Gamma("tau", 1.0, 1.0, shape=(J,K))
    lambda_ = pm.Gamma("lambda_", 10.0, 1.0, shape=(J,K))
    mu = pm.Normal("mu", 0, tau=lambda_ * tau, shape=(J,K)) #mean value of clusters, variance not explained
    
    obs = pm.NormalMixture("obs", w, mu, tau=lambda_ * tau, observed=data_in) #also tried comp_shape=(J,K)

And I get this error:

ValueError: Input dimension mismatch. One other input has shape[1] = 10, but input[1].shape[1] = 20. 
During handling of the above exception, another exception occurred: 
...

Inputs values: ['not shown', 'not shown', array([[1]], dtype=int8), array([[0]], dtype=int8)] Outputs clients: [[Sum{axis=[1], acc_dtype=float64}(Elemwise{Composite}.0), InplaceDimShuffle{x,0,1}(Elemwise{Composite}.0)], [All(Elemwise{Composite}.1)], [All(Elemwise{Composite}.2)]] 

HINT: Re-running with most PyTensor optimizations disabled could provide a back-trace showing when this node was created. This can be done by setting the PyTensor flag 'optimizer=fast_compile'. If that does not work, PyTensor optimizations can be disabled with 'optimizer=None'. HINT: Use the PyTensor flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.

The original code works for 1D data, so there must be some problem with the shapes, but I can't find where. The NormalMixture documentation states that the comp_shape argument “should be different than the shape of the mixture distribution, with the last axis representing the number of components”, so I assume the last dimension should be K. I already tried passing it as an argument, with similar results.
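For reference, here is a quick NumPy-only check of the shapes involved (a sketch, independent of PyMC). With the assignments above, data_in has 2 rows and 10 columns, so N = 2, J = 10 and K = 2, and shape=(J,K) is (10, 2). That suggests the "samples" and "features" comments are swapped relative to how the array is laid out.

```python
import numpy as np

data_in = np.asarray([[316, 271, 307, 322, 298, 268, 245, 268, 347, 356],
                      [142, 145, 161, 137, 150, 161, 188, 158, 139, 142]])

N = data_in.shape[0]  # length of the first axis: 2 rows
J = data_in.shape[1]  # length of the second axis: 10 columns
K = N                 # so K is 2 here, not 10

print(data_in.shape)    # (2, 10)
print((J, K))           # (10, 2)
print(data_in.T.shape)  # (10, 2)
```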

Does this work?

import pymc as pm
from pytensor import tensor as pt
import numpy as np

data_in = np.asarray([[316, 271, 307, 322, 298, 268, 245, 268, 347, 356],
                      [142, 145, 161, 137, 150, 161, 188, 158, 139, 142]])

N = data_in.shape[0]  # number of samples
J = data_in.shape[1]  # number of features
K = N  # number of sticks = maximum number of clusters ~= number of samples

def stick_breaking(beta):
    portion_remaining = pt.concatenate([[1], pt.extra_ops.cumprod(1 - beta)[:-1]], axis=0)
    return beta * portion_remaining

with pm.Model() as model:
    alpha = pm.Gamma("alpha", 1.0, 1.0)  # prior shape parameter
    beta = pm.Beta("beta", 1.0, alpha, shape=K)  # beta for stick-breaking process, takes alpha as hyperparameter
    w = pm.Deterministic("w", stick_breaking(beta))  # Dirichlet distributed variable using stick-breaking

    tau = pm.Gamma("tau", 1.0, 1.0, shape=K)
    lambda_ = pm.Gamma("lambda_", 10.0, 1.0, shape=K)
    mu = pm.Normal("mu", 0, tau=lambda_ * tau, shape=(J, K))  # mean value of clusters, variance not explained

    obs = pm.NormalMixture("obs", w, mu, tau=lambda_ * tau, observed=data_in.T)  # Note the transpose on data_in

Thank you for the feedback! It doesn’t work either. When sampling,

with model:
    trace = pm.sample(1000, tune=2500, init="advi", target_accept=0.9)

it raises a different but probably related error (the previous error also occurred when sampling, my mistake). I really do not understand how shapes are handled when generating the Dirichlet RV with the stick-breaking process. Maybe that is where the error is. In any case, I couldn’t find any examples of multidimensional clustering using DPMM.

ValueError Traceback (most recent call last) 
...
ValueError: Incompatible Elemwise input shapes [(10, 2, 1), (1, 10, 2)]
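On the question of how shape is handled in the stick-breaking step, a NumPy analogue may help. This is only a sketch: stick_breaking_np is a hypothetical stand-in for the stick_breaking function above, and the Beta(1, 2) draws are arbitrary illustration values. It shows that a 1D beta of length K gives K weights, while a 2D beta gets flattened by cumprod when no axis is given. I am assuming here that pt.extra_ops.cumprod follows the same axis=None convention as np.cumprod; if so, that would be consistent with the 10-versus-20 mismatch in the first error.

```python
import numpy as np

def stick_breaking_np(beta):
    # Fraction of the stick left before each break: 1, (1-b0), (1-b0)(1-b1), ...
    portion_remaining = np.concatenate([[1.0], np.cumprod(1.0 - beta)[:-1]])
    return beta * portion_remaining

rng = np.random.default_rng(0)

# 1D case: one break proportion per stick, one weight per stick.
K = 10
beta = rng.beta(1.0, 2.0, size=K)  # arbitrary illustrative draws
w = stick_breaking_np(beta)
print(w.shape)  # (10,)

# The weights sum to 1 minus the leftover piece of the stick.
leftover = np.prod(1.0 - beta)
print(np.isclose(w.sum(), 1.0 - leftover))  # True

# 2D case: with no axis argument, cumprod flattens the input first.
beta_2d = rng.beta(1.0, 2.0, size=(10, 2))
print(np.cumprod(1.0 - beta_2d).shape)  # (20,)
```

If the weights are meant to be shared across features, keeping beta (and therefore w) one-dimensional with length K, as in the reply above, avoids the flattening issue altogether.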