Hierarchical model with copula on GPU: NumPyro or nutpie using PyMC 5.19?

My compute cluster only supports PyMC 5.19 on Python 3.12.4, and I need to fit a bivariate joint model with copulas and two quite complicated marginal models (~2k parameters).

It took ~5 hours for 5k warmup + 5k sampling iterations on an H100 GPU, which is about 3x faster than the same run on a CPU. That is still quite slow, and I want to cut the sampling time so that my Bayesian workflow/model iteration can fail faster.

Any suggestions on how to speed up sampling? I use nutpie for CPU sampling, and the latest v6 uses it as the default. Is it also a better choice than NumPyro on the older v5 releases?

The following is my current sampling function:

import json
from typing import Any


def _sample_model(
    model: Any,
    draws: int,
    tune: int,
    seed: int,
    chains: int = 4,
    cores: int = 4,
    max_treedepth: int = 12,
    target_accept: float = 0.99,
    nuts_sampler: str | None = None,
    nutpie_backend: str | None = None,
    sampling_device: str = "gpu",
    chain_method: str | None = None,
    postprocessing_backend: str | None = None,
    jitter: bool = True,
    thin: int = 1,
) -> Any:
    if draws <= 0:
        raise ValueError("--fit-draws must be positive when no posterior file is supplied.")

    # Set up cache directories before PyMC is imported.
    _configure_runtime_cache_dirs()
    import pymc as pm

    # Resolve the sampler, device, and backend options into one config dict.
    sampling = _resolve_sampling_config(
        sampling_device=sampling_device,
        chains=chains,
        cores=cores,
        target_accept=target_accept,
        nuts_sampler=nuts_sampler,
        nutpie_backend=nutpie_backend,
        chain_method=chain_method,
        postprocessing_backend=postprocessing_backend,
        thin=thin,
        max_treedepth=max_treedepth,
        jitter=jitter,
    )
    print(f"Sampling configuration: {json.dumps(sampling, indent=2)}")

    # Abort early if a GPU run was requested but JAX cannot see a GPU.
    if sampling["sampling_device"] == "gpu":
        _jax_runtime_info(require_gpu=True)

    init_method = "jitter+adapt_diag" if jitter else "adapt_diag"
    with model:
        idata = pm.sample(
            draws=draws,
            tune=tune,
            chains=sampling["chains"],
            cores=sampling["cores"],
            target_accept=sampling["target_accept"],
            nuts_sampler=sampling["nuts_sampler"],
            nuts_sampler_kwargs=sampling["nuts_sampler_kwargs"],
            init=init_method,
            progressbar=False,
            random_seed=seed,
            compute_convergence_checks=True,
            return_inferencedata=True,
        )

    # Optionally keep every `thin`-th draw.
    if sampling["thin"] > 1:
        idata = idata.isel(draw=slice(None, None, sampling["thin"]))
    return idata

Thanks!

The usual suspect in slow fitting is the model itself: how it’s coded, or how well it matches the data. Andrew Gelman calls this the “folk theorem of statistical computing”: when you have computational problems, often there’s a problem with your model.
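One quick way to see whether the posterior geometry is what's costing you is to look at how often NUTS runs into the tree-depth cap. Here is a minimal sketch in plain Python. The function name `summarize_tree_depths` and the example depths are made up for illustration. In a real run the depths would come from the sampler statistics, for example something like `idata.sample_stats["tree_depth"]` flattened over chains and draws, though the exact stat name can differ between samplers, so treat that as an assumption to check.

```python
from collections import Counter


def summarize_tree_depths(depths, max_treedepth):
    """Summarize how often NUTS hit the tree-depth cap.

    `depths` is a flat iterable of per-draw tree depths (hypothetical input;
    in practice it would be extracted from the sampler statistics).
    """
    depths = [int(d) for d in depths]
    if not depths:
        raise ValueError("depths must not be empty")
    counts = Counter(depths)
    # Draws whose trajectory was cut off at (or beyond) the configured cap.
    at_cap = sum(n for d, n in counts.items() if d >= max_treedepth)
    return {
        "n_draws": len(depths),
        "mean_depth": sum(depths) / len(depths),
        "frac_at_cap": at_cap / len(depths),
        "counts": dict(sorted(counts.items())),
    }


# Made-up depths, for illustration only.
example_depths = [12, 12, 11, 12, 10, 12, 12, 9]
summary = summarize_tree_depths(example_depths, max_treedepth=12)
print(summary)
```

If a large fraction of draws sit at the cap with `max_treedepth=12` and `target_accept=0.99`, the step size is probably tiny and the trajectories very long. That usually points back at the parameterization or priors rather than at the hardware.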

I’d also be suspicious that the GPU code on an H100 is only 3 times as fast as the CPU code. Is it having to bounce in and out of the kernel, or are you using a BlackJAX-based sampler that’s itself coded in JAX?
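To put rough numbers on that, here is a back-of-the-envelope sketch using the timings from the question. It assumes the chains run concurrently, so wall time is per chain, and that each iteration costs on the order of `2 ** depth` gradient evaluations. Neither assumption has been verified for this run, and the helper name `per_gradient_budget` is just for illustration.

```python
def per_gradient_budget(wall_seconds, iterations, tree_depth):
    """Rough seconds per iteration and per gradient evaluation.

    Assumes chains run concurrently and that every iteration builds a full
    tree of the given depth (about 2**tree_depth leapfrog steps).
    """
    sec_per_iter = wall_seconds / iterations
    leapfrogs = 2 ** tree_depth  # approximate gradient evaluations per iteration
    return sec_per_iter, sec_per_iter / leapfrogs


wall = 5 * 3600        # ~5 hours, from the question
iters = 5_000 + 5_000  # warmup + draws per chain, from the question

for depth in (8, 10, 12):
    sec_iter, sec_grad = per_gradient_budget(wall, iters, depth)
    print(f"depth {depth}: {sec_iter:.2f} s/iter, {sec_grad * 1e3:.3f} ms per gradient")
```

That works out to about 1.8 s per iteration. If the trees really are saturating at depth 12, each gradient takes well under a millisecond, and at that scale a GPU may be spending much of its time on launch overhead and small operations rather than arithmetic, which would be consistent with only a 3x speedup. If the typical depth is much lower, each gradient is expensive and the model's log density itself is where to look.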