Dirichlet Gaussian Process Model - Suggestions for Improvement

I recently became interested in practical applications of Gaussian Processes, specifically applying GPs to multiclass classification problems. My initial model had K GPs for K classes, whose outputs are passed through the softmax function to become a probability distribution, with a Categorical likelihood over the observations. I then got the idea to swap out the Categorical for a Dirichlet. On paper this should be a superior approach, since it allows the model to flag predictions with very high uncertainty. Interestingly, this model can also express percentage mixture ratios, which are the type of data I encounter most often in my work.

My problem is that the Dirichlet-GP model below behaves pathologically on real-world data: it almost always defaults to random guessing (i.e. every sample gets approximately $\frac{1}{K}$ for each element). Testing the model on the Iris dataset yields reasonably good results, but applying it to a dataset of cultivar mixtures produces exactly this behavior.
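A quick way to see how the two likelihoods relate: with $\alpha = \exp(f)$, the Dirichlet mean $\alpha_k / \sum_j \alpha_j$ is exactly $\mathrm{softmax}(f)_k$, so the Dirichlet model predicts the same expected composition as the softmax/Categorical model, plus a total concentration $\sum_j \exp(f_j)$ that controls how spread out the draws are. The NumPy sketch below uses illustrative helper names and latent values (not taken from the model in this post) to check that identity and to show that shifting every $f_k$ by the same constant leaves the mean unchanged but rescales the concentration.

```python
import numpy as np

def softmax(f):
    """Numerically stable softmax over the last axis."""
    e = np.exp(f - f.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def dirichlet_summary(f):
    """Mean and total concentration of Dirichlet(alpha = exp(f))."""
    alpha = np.exp(f)
    total = alpha.sum(axis=-1)
    return alpha / alpha.sum(axis=-1, keepdims=True), total

f = np.array([2.0, 0.5, -1.0])  # illustrative latent values for K = 3
mean, total = dirichlet_summary(f)
print(np.allclose(mean, softmax(f)))  # True

# Adding a constant to f leaves softmax(f) unchanged but multiplies every
# alpha by exp(shift): small totals push draws toward the simplex corners,
# large totals concentrate draws near the mean.
rng = np.random.default_rng(0)
for shift in (-3.0, 0.0, 3.0):
    draws = rng.dirichlet(np.exp(f + shift), size=5000)
    print(shift, round(float(dirichlet_summary(f + shift)[1]), 2), draws.std(axis=0).round(3))
```

In other words, the extra information the Dirichlet carries lives entirely in the overall level of $f$, which the softmax ignores.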
The dataset has around 13 columns representing different cultivars, and each observation is a (possibly pure) mixture of these, i.e. an observation has the form $(0.7, 0.3, 0.0, \dots)$. The columns seem to be reasonably strongly correlated. The inputs are around 500 observations × 63 features, all continuous. For comparison's sake, I trained a relatively simple ANN, which seems to perform quite well. My models are shown below. The inputs are standardized, and for the Dirichlet model I also offset zeros and ones by a small perturbation $\epsilon = 10^{-4}$.
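The preprocessing described above can be sketched as follows, assuming the targets are an array of row-wise proportions; `standardize` and `to_open_simplex` are hypothetical helper names. Since the Dirichlet density is defined only on the open simplex, this version replaces zeros with $\epsilon$ and then renormalizes each row, so the rows still sum to 1 and former ones end up strictly below 1. Offsetting zeros and ones independently would instead leave a row like $(0.7, 0.3, 0, \dots)$ summing to slightly more than 1.

```python
import numpy as np

def standardize(X):
    """Z-score each input column (zero mean, unit variance)."""
    X = np.asarray(X, dtype=float)
    return (X - X.mean(axis=0)) / X.std(axis=0)

def to_open_simplex(Y, eps=1e-4):
    """Move compositions off the simplex boundary for a Dirichlet likelihood.

    Zeros become eps, then each row is renormalized so it still sums to 1;
    the renormalization also pulls former ones strictly below 1.
    """
    Y = np.asarray(Y, dtype=float)
    Y = np.where(Y <= 0.0, eps, Y)
    return Y / Y.sum(axis=1, keepdims=True)

Y = np.array([[0.7, 0.3, 0.0, 0.0],
              [1.0, 0.0, 0.0, 0.0]])
Y_adj = to_open_simplex(Y)
print(Y_adj.sum(axis=1))                         # each row sums to 1
print((Y_adj > 0).all() and (Y_adj < 1).all())   # True
```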

```python
# Dirichlet GP general model
# X_train, Y_train (pandas DataFrames), x_train_tensor, y_train_tensor,
# coords and the NOISE flag are defined elsewhere.
import pymc
import pytensor.tensor as pt

with pymc.Model(coords=coords) as cultivar_model:
    _, M = X_train.shape
    factors = Y_train.columns
    n_factors = len(factors)
    gps = []
    latent_functions = []
    lower = 1e-4
    if NOISE:
        σ_noise = pymc.TruncatedNormal('σ_noise', mu=3.0, sigma=1.0, shape=n_factors, lower=lower)
    λ = pymc.TruncatedNormal('λ', mu=3.0, sigma=0.8, shape=(n_factors, M), lower=lower)
    η = pymc.TruncatedNormal('η', mu=10.0, sigma=3.5, shape=n_factors, lower=lower)
    for idx, factor in enumerate(factors):
        # One independent latent GP per cultivar, with per-input (ARD) lengthscales
        μ = pymc.gp.mean.Constant(c=0)
        κ_predictive = η[idx]**2 * pymc.gp.cov.ExpQuad(M, ls=λ[idx, :])
        if NOISE:
            κ_noise = pymc.gp.cov.WhiteNoise(σ_noise[idx])
            κ = κ_noise + κ_predictive
        else:
            κ = κ_predictive
        gp = pymc.gp.Latent(mean_func=μ, cov_func=κ)
        f_raw = gp.prior(f'f_raw_{factor}', x_train_tensor, reparameterize=False)
        gps.append(gp)
        latent_functions.append(f_raw)
    # pt.stack takes a list; stacking on axis=1 gives shape (N, K).
    # dims omitted: f is 2-D, so the single 'cultivars' label does not match.
    f = pymc.Deterministic('f', pt.stack(latent_functions, axis=1))
    # exp keeps every concentration parameter strictly positive
    α = pymc.Deterministic('α', pymc.math.exp(f))
    # Each row of y_train_tensor is one composition; the (N, K) shape is inferred from observed
    y_obs = pymc.Dirichlet('y_obs', a=α, observed=y_train_tensor)
```
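One way to check what the priors on λ and η imply before fitting is a quick prior simulation. The NumPy sketch below is a stand-in, not the PyMC model itself: it assumes synthetic standard-normal inputs in place of the standardized `X_train`, fixes λ and η at their prior means (3 and 10), and draws K latent functions from the corresponding ExpQuad GP. It prints the typical prior correlation between two distinct inputs and the range of $\alpha = \exp(f)$.

```python
import numpy as np

rng = np.random.default_rng(1)

def exp_quad(XA, XB, ls, eta):
    """ExpQuad kernel with per-dimension lengthscales ls and amplitude eta,
    mirroring eta**2 * ExpQuad(M, ls=ls) in the model above."""
    diff = (XA[:, None, :] - XB[None, :, :]) / ls
    return eta**2 * np.exp(-0.5 * np.sum(diff**2, axis=-1))

N, M, K = 200, 63, 13            # N smaller than the real ~500 to keep this quick
X = rng.standard_normal((N, M))  # synthetic stand-in for standardized inputs (assumption)

ls = np.full(M, 3.0)             # prior mean of λ
eta = 10.0                       # prior mean of η

Kxx = exp_quad(X, X, ls, eta)
corr = Kxx / eta**2
off_diag = corr[~np.eye(N, dtype=bool)]
print("median prior correlation between distinct inputs:", np.median(off_diag))

# Draw K latent functions from the GP prior and look at alpha = exp(f)
L = np.linalg.cholesky(Kxx + 1e-6 * np.eye(N))
f = L @ rng.standard_normal((N, K))  # shape (N, K)
alpha = np.exp(f)
print("log10(alpha), 1st and 99th percentiles:", np.percentile(np.log10(alpha), [1, 99]))
```

With 63 standardized inputs, the expected squared scaled distance between two independent points is about $2 \cdot 63 / 3^2 = 14$, so their prior correlation is on the order of $e^{-7}$. If the real inputs behave similarly, the latent functions are nearly uncorrelated across training points, and away from the training data each $f_k$ reverts to the zero mean function, where all $\alpha_k$ are equal and the Dirichlet mean is $\frac{1}{K}$. The $\eta \approx 10$ amplitude, meanwhile, lets $\alpha$ span many orders of magnitude.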
```python
# Categorical GP
# Same setup as above. y_train_labels is a hypothetical integer vector of class
# indices (e.g. Y_train.values.argmax(axis=1)), because a Categorical likelihood
# needs one hard label per observation rather than a mixture.
import pymc
import pytensor.tensor as pt
from pytensor.tensor.special import softmax

with pymc.Model(coords=coords) as cultivar_model:
    _, M = X_train.shape
    factors = Y_train.columns
    n_factors = len(factors)
    gps = []
    latent_functions = []
    lower = 1e-4
    if NOISE:
        σ_noise = pymc.TruncatedNormal('σ_noise', mu=3.0, sigma=1.0, shape=n_factors, lower=lower)
    λ = pymc.TruncatedNormal('λ', mu=3.0, sigma=0.8, shape=(n_factors, M), lower=lower)
    η = pymc.TruncatedNormal('η', mu=10.0, sigma=3.5, shape=n_factors, lower=lower)
    for idx, factor in enumerate(factors):
        μ = pymc.gp.mean.Constant(c=0)
        κ_predictive = η[idx]**2 * pymc.gp.cov.ExpQuad(M, ls=λ[idx, :])
        if NOISE:
            κ_noise = pymc.gp.cov.WhiteNoise(σ_noise[idx])
            κ = κ_noise + κ_predictive
        else:
            κ = κ_predictive
        gp = pymc.gp.Latent(mean_func=μ, cov_func=κ)
        f_raw = gp.prior(f'f_raw_{factor}', x_train_tensor, reparameterize=False)
        gps.append(gp)
        latent_functions.append(f_raw)
    f = pymc.Deterministic('f', pt.stack(latent_functions, axis=1))  # shape (N, K)
    # Softmax across the class axis turns each row of f into class probabilities
    p = pymc.Deterministic('p', softmax(f, axis=1))
    y_obs = pymc.Categorical('y_obs', p=p, observed=y_train_labels)
```
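To make the difference between the two likelihoods concrete, here is a NumPy/SciPy sketch of the per-observation log-likelihood each one assigns for the same latent vector f: the Categorical term $\log \mathrm{softmax}(f)_y$ uses only a hard label (here assumed to be the argmax of the composition), while the Dirichlet term with $\alpha = \exp(f)$ scores the full composition. The function names and numbers are illustrative only.

```python
import numpy as np
from scipy.special import logsumexp
from scipy.stats import dirichlet

def categorical_loglik(f, label):
    """log p(y = label | f) with p = softmax(f)."""
    return f[label] - logsumexp(f)

def dirichlet_loglik(f, y):
    """log Dirichlet(y | alpha = exp(f)); y must lie strictly inside the simplex."""
    return dirichlet.logpdf(y, np.exp(f))

f = np.array([1.5, 1.0, -2.0])
y = np.array([0.6, 0.39, 0.01])  # illustrative composition, sums to 1
label = int(np.argmax(y))        # hard label a Categorical model would see

print(categorical_loglik(f, label))
print(dirichlet_loglik(f, y))
```

Because the Categorical term depends only on `label`, a 0.6/0.39 split and a pure sample of the first cultivar score identically under it, whereas the Dirichlet term distinguishes them.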

Any ideas why this isn't working? Is it something with my prior, perhaps? I'll note that MCMC itself is fairly fast, but posterior predictive sampling is not: it takes almost 2 hours.
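Assuming predictions are made the usual Latent-GP way (`gp.conditional` followed by `pymc.sample_posterior_predictive`), each posterior draw has different λ and η values, so every one of the K GPs must be re-conditioned on the N training points for every draw, which involves Cholesky factorizations and solves with N × N matrices. The NumPy sketch below (hypothetical helper names, synthetic data) shows that noise-free conditional for one latent function and one draw; this is the work repeated roughly K × (number of draws) times.

```python
import numpy as np

def exp_quad(XA, XB, ls, eta):
    """ExpQuad kernel with per-dimension lengthscales and amplitude eta."""
    diff = (XA[:, None, :] - XB[None, :, :]) / ls
    return eta**2 * np.exp(-0.5 * np.sum(diff**2, axis=-1))

def latent_conditional(X, f, X_new, ls, eta, jitter=1e-6):
    """Mean and covariance of f(X_new) given f(X) for a zero-mean GP."""
    Kxx = exp_quad(X, X, ls, eta) + jitter * np.eye(len(X))
    Kxs = exp_quad(X, X_new, ls, eta)
    Kss = exp_quad(X_new, X_new, ls, eta)
    L = np.linalg.cholesky(Kxx)   # O(N^3), redone whenever ls or eta change
    A = np.linalg.solve(L, Kxs)   # L^{-1} K_xs
    v = np.linalg.solve(L, f)     # L^{-1} f
    mean = A.T @ v                # K_sx K_xx^{-1} f
    cov = Kss - A.T @ A           # K_ss - K_sx K_xx^{-1} K_xs
    return mean, cov

rng = np.random.default_rng(2)
X = rng.standard_normal((50, 4))
X_new = rng.standard_normal((10, 4))
ls, eta = np.full(4, 1.5), 2.0
# One synthetic "posterior draw" of the latent function at the training inputs
f = np.linalg.cholesky(exp_quad(X, X, ls, eta) + 1e-6 * np.eye(50)) @ rng.standard_normal(50)
mean, cov = latent_conditional(X, f, X_new, ls, eta)
```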