I am following this excellent introduction to Dirichlet process mixture models (DPMM), but I want to use the general scheme for a classification task and generalize it to run on multidimensional samples. This is the code:

```
import pymc as pm
from pytensor import tensor as tt
import seaborn as sns
import arviz as az
import numpy as np

data_in = np.asarray([[316, 271, 307, 322, 298, 268, 245, 268, 347, 356],
                      [142, 145, 161, 137, 150, 161, 188, 158, 139, 142]])

N = data_in.shape[0]  # number of samples
J = data_in.shape[1]  # number of features
K = N                 # number of sticks = maximum number of clusters ~= number of samples
# data_in shape is (J, K)

def stick_breaking(beta):
    portion_remaining = tt.concatenate([[1], tt.extra_ops.cumprod(1 - beta)[:-1]])
    return beta * portion_remaining

with pm.Model() as model:
    alpha = pm.Gamma("alpha", 1.0, 1.0)  # prior on the concentration parameter
    beta = pm.Beta("beta", 1.0, alpha, shape=(J, K))  # betas for the stick-breaking process, with alpha as hyperparameter
    w = pm.Deterministic("w", stick_breaking(beta))  # Dirichlet-distributed weights via stick-breaking
    tau = pm.Gamma("tau", 1.0, 1.0, shape=(J, K))
    lambda_ = pm.Gamma("lambda_", 10.0, 1.0, shape=(J, K))
    mu = pm.Normal("mu", 0, tau=lambda_ * tau, shape=(J, K))  # cluster means (variance choice not explained)
    obs = pm.NormalMixture("obs", w, mu, tau=lambda_ * tau, observed=data_in)  # also tried comp_shape=(J, K)
```

And I get this error:

```
ValueError: Input dimension mismatch. One other input has shape[1] = 10, but input[1].shape[1] = 20.
During handling of the above exception, another exception occurred:
...
Inputs values: ['not shown', 'not shown', array([[1]], dtype=int8), array([[0]], dtype=int8)] Outputs clients: [[Sum{axis=[1], acc_dtype=float64}(Elemwise{Composite}.0), InplaceDimShuffle{x,0,1}(Elemwise{Composite}.0)], [All(Elemwise{Composite}.1)], [All(Elemwise{Composite}.2)]]
HINT: Re-running with most PyTensor optimizations disabled could provide a back-trace showing when this node was created. This can be done by setting the PyTensor flag 'optimizer=fast_compile'. If that does not work, PyTensor optimizations can be disabled with 'optimizer=None'. HINT: Use the PyTensor flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.
```

The original code works for 1D data, so there must be some problem with the shapes, but I can't find where. In the NormalMixture function definition, it is stated that the comp_shape argument “should be different than the shape of the mixture distribution, with the last axis representing the number of components”, so I assume the last dimension should be K. I already tried passing it as an argument but got similar results.
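
To make my intent clearer, here is a quick NumPy sketch (just my own assumption, with a throwaway `beta_np` draw standing in for the model's `beta`) of the weight shape I expect `stick_breaking` to produce for a `(J, K)` beta: each row broken independently along the last (component) axis.

```
import numpy as np

J, K = 10, 2  # same values as in the model above

# Hypothetical check: draw a beta of the same shape I give pm.Beta("beta", ...)
rng = np.random.default_rng(0)
beta_np = rng.beta(1.0, 1.0, size=(J, K))

# Break each row independently along the K (component) axis
remaining = np.concatenate(
    [np.ones((J, 1)), np.cumprod(1 - beta_np, axis=-1)[:, :-1]], axis=-1
)
w_np = beta_np * remaining
print(w_np.shape)  # (J, K): one length-K weight vector per row, which is what I intend
```

I am not sure whether my `stick_breaking` above actually does this, since it calls `tt.extra_ops.cumprod` without an `axis` argument; if pytensor flattens the input in that case, that might be where the 20 (= J * K) in the error message comes from, but I could be wrong.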