Kruschke-style fixed-threshold cumulative ordinal probit model

Cool! You could take advantage of PyMC’s built-in support for coords and dims to name the model’s dimensions. I don’t think there’s anything inefficient about the concatenate, but others would know more. As for sampling speed, there are a couple of other NUTS implementations you might try:

Using the JAX backend, you can try NumPyro’s sampler:

import pymc as pm
import pymc.sampling_jax

with model:
    idata = pm.sampling_jax.sample_numpyro_nuts(chains=4, tune=1000, draws=1000, target_accept=0.8)

or BlackJAX:

import pymc as pm
import pymc.sampling_jax

with model:
    idata = pm.sampling_jax.sample_blackjax_nuts(chains=4, tune=1000, draws=1000, target_accept=0.8)

Using the Numba backend, there is nutpie:

import nutpie

# Compile once, then pass the compiled model (not the PyMC model) to the sampler.
compiled_model = nutpie.compile_pymc_model(model)
idata = nutpie.sample(compiled_model, chains=4, tune=1000, draws=1000, target_accept=0.8)