Cool! You could take advantage of PyMC's built-in support for coords and dims. I don’t think there’s anything inefficient about the concatenate, but others would know more. For the sampling speed part, there are a couple of other NUTS implementations you might try:
Using the JAX backend, you can try NumPyro’s sampler:
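```python
import pymc as pm  # needed so that the `pm.` prefix below resolves
import pymc.sampling_jax  # newer PyMC releases move this to pymc.sampling.jax

with model:
    idata = pm.sampling_jax.sample_numpyro_nuts(chains=4, tune=1000, draws=1000, target_accept=0.8)
```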
or BlackJAX:
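```python
import pymc as pm  # needed so that the `pm.` prefix below resolves
import pymc.sampling_jax  # newer PyMC releases move this to pymc.sampling.jax

with model:
    idata = pm.sampling_jax.sample_blackjax_nuts(chains=4, tune=1000, draws=1000, target_accept=0.8)
```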
Using the Numba backend, there is nutpie:
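```python
import nutpie

with model:
    # nutpie compiles the PyMC model first, then samples from the compiled model
    compiled_model = nutpie.compile_pymc_model(model)
    idata = nutpie.sample(compiled_model, chains=4, tune=1000, draws=1000, target_accept=0.8)
```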