I’m building a multiclass GP model, but I’m running into problems with the Categorical likelihood. The following model samples with vanilla `pymc.sample()`:
```python
import pymc
import aesara.tensor as at

# coords, factors, M, X_train and Y_train are defined earlier in my script.
with pymc.Model(coords=coords) as location_model:
    # η = pymc.Normal('η', mu=1.0, sigma=.2)
    η = 1
    latent_functions = []  # one latent GP per class (location)
    for location in factors:
        μ = pymc.gp.mean.Constant(c=0)
        # Noise kernel
        σ_ν = pymc.Normal(f'σ_ν_{location}', mu=2.0, sigma=.5)
        κ_wn = pymc.gp.cov.WhiteNoise(σ_ν**2)
        # Predictive kernel (RBF)
        ℓ = pymc.Normal(f'ℓ_{location}', mu=10.0, sigma=.2, shape=M)
        κ_se = pymc.gp.cov.ExpQuad(M, ls=ℓ)
        κ = κ_se + κ_wn
        # κ_mlp = MultiLayerPerceptronKernel(M,variance=, bias_variance=, weight_variance=)
        # Initialization of GPs
        gp = pymc.gp.Latent(mean_func=μ, cov_func=κ)
        _f = gp.prior(f'_f_{location}', X=X_train.values, reparameterize=False)
        latent_functions.append(_f)
    # Stack the K latent functions into an (N, K) matrix, then softmax each row
    f = pymc.Deterministic('f', at.stack(latent_functions).T)
    p = pymc.Deterministic('p', at.nnet.softmax(f, axis=1))
    y_obs = pymc.Categorical('y_obs', p=p, observed=Y_train.values[:, 0])
```
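To sanity-check the shapes outside PyMC, here is a minimal NumPy sketch of the same forward pass, assuming K classes (one latent GP per entry of `factors`) and N training points. Random normal vectors stand in for the GP draws, and `softmax_rows` is a hypothetical helper, not a PyMC or aesara function.

```python
import numpy as np

def softmax_rows(f):
    """Row-wise softmax: each row of an (N, K) array becomes a probability vector."""
    # Subtracting the row max improves numerical stability without changing the result.
    z = f - f.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

rng = np.random.default_rng(0)
N, K = 5, 3  # hypothetical: 5 training points, 3 classes (one GP per class)

# Stand-ins for the K latent GP draws, each a length-N vector (like `_f` per location).
latent_functions = [rng.normal(size=N) for _ in range(K)]

# Stacking along axis=1 gives the same (N, K) layout as `at.stack(...).T`.
f = np.stack(latent_functions, axis=1)
p = softmax_rows(f)

print(f.shape, p.shape)                  # (5, 3) (5, 3)
print(np.allclose(p.sum(axis=1), 1.0))   # True
```

Each row of `p` is then the class-probability vector for one training point, which is what `pymc.Categorical` receives.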
I’ve had issues with `pymc.Categorical` behaving oddly in the past. Specifically, when `p` is a 2D array, shapes don’t seem to be inferred properly unless `observed` is strictly a vector. Setting the also doesn’t work. The model above samples fine with the default NUTS sampler, but not with JAX: it throws a concretization error.
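For reference, the likelihood here is just one log-probability per row of `p`. The sketch below (plain NumPy; `categorical_logp` is a hypothetical helper, not a PyMC function) spells out the shape pairing: `p` of shape (N, K) and integer labels of shape (N,). That is consistent with `Y_train.values[:, 0]` (a vector) behaving, while an (N, 1) column does not pair up row by row. The same indexing expression could presumably be written with aesara tensors inside `pymc.Potential` as a workaround, but I haven't checked whether that avoids the JAX error.

```python
import numpy as np

def categorical_logp(p, y):
    """Sum of log-probabilities of integer labels y under row-wise class probabilities p.

    p: (N, K) array, each row a probability vector.
    y: (N,) integer array of class labels in [0, K).
    """
    p = np.asarray(p)
    y = np.asarray(y)
    # One label per row of p: an (N, 1) column or any other shape is rejected.
    if p.ndim != 2 or y.shape != (p.shape[0],):
        raise ValueError(f"expected p (N, K) and y (N,), got {p.shape} and {y.shape}")
    # Pick p[i, y[i]] for every row i, then sum the logs.
    return np.log(p[np.arange(p.shape[0]), y]).sum()

p = np.array([[0.7, 0.2, 0.1],
              [0.1, 0.8, 0.1],
              [0.3, 0.3, 0.4]])
y = np.array([0, 1, 2])
print(categorical_logp(p, y))  # log(0.7) + log(0.8) + log(0.4)
```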