My compute cluster only supports PyMC 5.19 using Python 3.12.4, and I need to fit a bivariate joint model with copulas and two quite complicated marginal models (~2k parameters).
It took ~5 hours for 5k warmup (tuning) iterations + 5k sampling iterations on an H100 GPU, which is 3x faster than running the same model on a CPU. However, that is still quite slow, and I want to reduce the sampling time further so that I can iterate on the model faster and have my Bayesian workflow fail fast.
Any suggestions on how to speed up sampling? Nutpie is what I use for CPU sampling, and the latest v6 uses it as the default; is it also better than NumPyro on the older v5 releases?
The following is my current sampling function:
def _sample_model(
    model: Any,
    draws: int,
    tune: int,
    seed: int,
    chains: int = 4,
    cores: int = 4,
    max_treedepth: int = 12,
    target_accept: float = 0.99,
    nuts_sampler: str | None = None,
    nutpie_backend: str | None = None,
    sampling_device: str = "gpu",
    chain_method: str | None = None,
    postprocessing_backend: str | None = None,
    jitter: bool = True,
    thin: int = 1,
) -> Any:
    """Run ``pm.sample`` on *model* using a resolved sampling configuration.

    The raw sampler options are handed to ``_resolve_sampling_config``; the
    values actually given to ``pm.sample`` (chains, cores, target_accept,
    nuts_sampler, nuts_sampler_kwargs) and the thinning step are read back
    from the mapping it returns. The resolved configuration is printed as
    JSON before sampling starts.

    Args:
        model: Object entered as a context manager around ``pm.sample``.
        draws: Number of posterior draws; must be positive.
        tune: Number of tuning iterations passed to ``pm.sample``.
        seed: Passed to ``pm.sample`` as ``random_seed``.
        chains, cores, max_treedepth, target_accept, nuts_sampler,
        nutpie_backend, sampling_device, chain_method,
        postprocessing_backend, thin: Forwarded unchanged to
            ``_resolve_sampling_config``.
        jitter: Forwarded to ``_resolve_sampling_config`` and also selects
            the ``init`` string given to ``pm.sample``
            (``"jitter+adapt_diag"`` when true, ``"adapt_diag"`` otherwise).

    Returns:
        The object returned by ``pm.sample``; when the resolved ``thin`` is
        greater than 1, that object sliced along ``draw`` with that step.

    Raises:
        ValueError: If ``draws`` is zero or negative.
    """
    if draws <= 0:
        raise ValueError("--fit-draws must be positive when no posterior file is supplied.")

    # NOTE(review): the cache directories are configured before PyMC is
    # imported — presumably the import reads them; keep this order.
    _configure_runtime_cache_dirs()
    import pymc as pm

    resolver_inputs = {
        "sampling_device": sampling_device,
        "chains": chains,
        "cores": cores,
        "target_accept": target_accept,
        "nuts_sampler": nuts_sampler,
        "nutpie_backend": nutpie_backend,
        "chain_method": chain_method,
        "postprocessing_backend": postprocessing_backend,
        "thin": thin,
        "max_treedepth": max_treedepth,
        "jitter": jitter,
    }
    cfg = _resolve_sampling_config(**resolver_inputs)

    cfg_dump = json.dumps(cfg, indent=2)
    print(f"Sampling configuration: {cfg_dump}")

    # The GPU check is keyed on the *resolved* device, not the raw argument.
    if cfg["sampling_device"] == "gpu":
        _jax_runtime_info(require_gpu=True)

    if jitter:
        init_strategy = "jitter+adapt_diag"
    else:
        init_strategy = "adapt_diag"

    with model:
        sample_kwargs = {
            "draws": draws,
            "tune": tune,
            "chains": cfg["chains"],
            "cores": cfg["cores"],
            "target_accept": cfg["target_accept"],
            "nuts_sampler": cfg["nuts_sampler"],
            "nuts_sampler_kwargs": cfg["nuts_sampler_kwargs"],
            "init": init_strategy,
            "progressbar": False,
            "random_seed": seed,
            "compute_convergence_checks": True,
            "return_inferencedata": True,
        }
        trace = pm.sample(**sample_kwargs)

    # Thinning is applied after sampling, by striding over the draw dimension.
    if cfg["thin"] > 1:
        return trace.isel(draw=slice(None, None, cfg["thin"]))
    return trace
Thanks!