Problems with Categorical and JAX

I’m building a Multiclass GP model but I’m encountering problems with the Categorical likelihood. The following model samples with vanilla pymc.sample():

import aesara.tensor as at
import pymc

with pymc.Model(coords=coords) as location_model:
    gps=[]
    latent_functions=[]
    _,M=X_train.shape
    # η = pymc.Normal('η', mu=1.0, sigma=.2)
    η = 1
    for location in factors:
        μ = pymc.gp.mean.Constant(c=0)
        # Noise Kernel
        σ_ν= pymc.Normal(f'σ_ν_{location}',mu=2.0, sigma=.5)
        κ_wn = pymc.gp.cov.WhiteNoise(σ_ν**2)
        # Predictive Kernel RBF
        ℓ= pymc.Normal(f'ℓ_{location}', mu=10.0, sigma=.2,shape=M)
        κ_se = pymc.gp.cov.ExpQuad(M, ls=ℓ)
        κ = κ_se+κ_wn        
        # κ_mlp = MultiLayerPerceptronKernel(M,variance=, bias_variance=, weight_variance=)
        
        # Initialization of GPs
        gp = pymc.gp.Latent(mean_func=μ, cov_func=κ)
        _f= gp.prior(f'_f_{location}', X=X_train.values, reparameterize=False)
        latent_functions.append(_f)
        gps.append(gp)
    f = pymc.Deterministic('f', at.stack(*latent_functions).T)
    p = pymc.Deterministic('p', at.nnet.softmax(f, axis=1))
    y_obs=pymc.Categorical('y_obs', p=p, observed=Y_train.values[:,0])

I’ve had issues with pymc.Categorical behaving oddly in the past. Specifically, when p is a 2D array, the shapes don’t seem to be inferred properly unless observed is strictly a vector. Setting the shape explicitly also doesn’t work. The model above samples fine with the default NUTS sampler but fails with the JAX sampler, which throws a concretization error.
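
To make the shape layout in the model above concrete, here is a minimal NumPy sketch (not PyMC code) of what I believe the likelihood expects: p with one row per observation and one column per class, and observed as a flat vector of integer labels. The sizes N and K, the random stand-ins for the latent functions, and the helper softmax are all assumptions made up for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
N, K = 5, 3  # hypothetical sizes: N observations, K classes

# One latent function per class, each of length N
# (random stand-ins for the GP priors in the model).
latent_functions = [rng.normal(size=N) for _ in range(K)]

# Stack to (K, N), then transpose so rows are observations
# and columns are classes.
f = np.stack(latent_functions).T  # shape (N, K)


def softmax(x, axis):
    # Subtract the max along `axis` for numerical stability.
    z = x - x.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


p = softmax(f, axis=1)  # each row sums to 1

# Integer class labels as a flat vector of length N.
y = rng.integers(0, K, size=N)

# Categorical log-likelihood: pick p[i, y[i]] for every observation i.
logp = np.log(p[np.arange(N), y]).sum()

print(f.shape, p.shape, y.shape)
```

With this layout the class axis is the last axis of p, which, as far as I understand, is the axis pymc.Categorical normalizes over.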

Can you share the full error message?

Will do. As an update, I “brute forced” the correct shapes, which are:

with pymc.Model(coords=coords) as location_model:
    gps=[]
    latent_functions=[]
    _,M=X_train.shape
    # η = pymc.Normal('η', mu=1.0, sigma=.2)
    η = 1e7
    for location in factors:
        μ = pymc.gp.mean.Constant(c=0)
        # Noise Kernel
        σ_ν= pymc.Normal(f'σ_ν_{location}',mu=2.0, sigma=.5)
        κ_wn = pymc.gp.cov.WhiteNoise(σ_ν**2)
        # Predictive Kernel: RBF
        # ℓ= pymc.Normal(f'ℓ_{location}', mu=15.0, sigma=4.8,shape=M)
        # κ_se = pymc.gp.cov.ExpQuad(M, ls=ℓ)
        # κ = κ_se+κ_wn
        # Alternative formulation
        # Predictive Kernel: MLP
        σ_b = pymc.Normal(f'σ_b_{location}', mu=10.0, sigma=1)
        σ_w = pymc.Normal(f'σ_w_{location}',mu=10.0, sigma=1)
        κ_mlp = MultiLayerPerceptronKernel(M,variance=η, bias_variance=σ_b, weight_variance=σ_w)
        κ = κ_mlp+κ_wn
        # κ = pymc.gp.cov.Linear(c=1)*κ_mlp+κ_wn
        # Initialization of GPs
        gp = pymc.gp.Latent(mean_func=μ, cov_func=κ)
        _f= gp.prior(f'_f_{location}', X=X_train.values, reparameterize=False)
        latent_functions.append(_f)
        gps.append(gp)
    f = pymc.Deterministic('f', at.stack(*latent_functions))
    p = pymc.Deterministic('p', at.nnet.softmax(f, axis=0))
    y_obs=pymc.Categorical('y_obs', p=p, observed=Y_train.values)

I think both samplers should behave the same, however.
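
For comparison, the two layouts used in this thread (stack, transpose, softmax over axis=1 versus stack, softmax over axis=0) can be checked against each other outside of PyMC. The NumPy sketch below uses made-up sizes and random values in place of the latent functions. It only illustrates the array shapes and is not a diagnosis of the JAX error.

```python
import numpy as np


def softmax(x, axis):
    # Subtract the max along `axis` for numerical stability.
    z = x - x.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


rng = np.random.default_rng(1)
K, N = 3, 6  # hypothetical sizes: K classes, N observations
f_kn = rng.normal(size=(K, N))  # stand-in for at.stack(*latent_functions)

p_kn = softmax(f_kn, axis=0)    # layout (K, N), normalized over classes
p_nk = softmax(f_kn.T, axis=1)  # layout (N, K), normalized over classes

# Both layouts hold the same probabilities, just transposed.
same_probabilities = np.allclose(p_kn.T, p_nk)

# Sums over the LAST axis differ: only the (N, K) layout sums to 1 there.
last_axis_sums_kn = p_kn.sum(axis=-1)  # length K, generally not all 1
last_axis_sums_nk = p_nk.sum(axis=-1)  # length N, all 1

print(same_probabilities, last_axis_sums_kn, last_axis_sums_nk)
```

If the likelihood reads the categories from the last axis of p (my understanding of pymc.Categorical, but treat that as an assumption), then only the (N, K) layout hands it properly normalized rows, even though both arrays contain the same numbers.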