Numpyro sampler

Hi,
I’m trying to use the numpyro sampler, but I get the error below.

Are there any recent working examples?
Otherwise, does anyone have insight into what is wrong?

 173     trace =pm.sampling.jax.sample_numpyro_nuts(dXfit, chains=4)


File ~/anaconda3/envs/pymc_env/lib/python3.12/site-packages/pymc/sampling/jax.py:611, in sample_jax_nuts(draws, tune, chains, target_accept, random_seed, initvals, jitter, model, var_names, nuts_kwargs, progressbar, keep_untransformed, chain_method, postprocessing_backend, postprocessing_vectorize, postprocessing_chunks, idata_kwargs, compute_convergence_checks, nuts_sampler)
    608     raise ValueError(f"{nuts_sampler=} not recognized")
    610 tic1 = datetime.now()
--> 611 raw_mcmc_samples, sample_stats, library = sampler_fn(
    612     model=model,
    613     target_accept=target_accept,
    614     tune=tune,
    615     draws=draws,
    616     chains=chains,
    617     chain_method=chain_method,
    618     progressbar=progressbar,
    619     random_seed=random_seed,
    620     initial_points=initial_points,
    621     nuts_kwargs=nuts_kwargs,
    622 )
    623 tic2 = datetime.now()
    625 if idata_kwargs is None:

File ~/anaconda3/envs/pymc_env/lib/python3.12/site-packages/pymc/sampling/jax.py:437, in _sample_numpyro_nuts(model, target_accept, tune, draws, chains, chain_method, progressbar, random_seed, initial_points, nuts_kwargs)
    429 nuts_kwargs.setdefault("dense_mass", False)
    431 nuts_kernel = NUTS(
    432     potential_fn=logp_fn,
    433     target_accept_prob=target_accept,
    434     **nuts_kwargs,
    435 )
--> 437 pmap_numpyro = MCMC(
    438     nuts_kernel,
    439     num_warmup=tune,
    440     num_samples=draws,
    441     num_chains=chains,
    442     postprocess_fn=None,
    443     chain_method=chain_method,
    444     progress_bar=progressbar,
    445 )
    447 map_seed = jax.random.PRNGKey(random_seed)
    448 if chains > 1:

File ~/anaconda3/envs/pymc_env/lib/python3.12/site-packages/numpyro/infer/mcmc.py:382, in MCMC.__init__(self, sampler, num_warmup, num_samples, num_chains, thinning, postprocess_fn, chain_method, progress_bar, jit_model_args)
    380 self._cache = {}
    381 self._collection_params = {}
--> 382 self._set_collection_params()

File ~/anaconda3/envs/pymc_env/lib/python3.12/site-packages/numpyro/infer/mcmc.py:514, in MCMC._set_collection_params(self, lower, upper, collection_size, phase)
    509 def _set_collection_params(
    510     self, lower=None, upper=None, collection_size=None, phase=None
    511 ):
    512     self._collection_params["lower"] = self.num_warmup if lower is None else lower
    513     self._collection_params["upper"] = (
--> 514         self.num_warmup + self.num_samples if upper is None else upper
    515     )
    516     self._collection_params["collection_size"] = collection_size
    517     self._collection_params["phase"] = phase

TypeError: unsupported operand type(s) for +: 'int' and 'Model'

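Try using pm.sample(nuts_sampler="numpyro") inside your model context (with dXfit: ...), and make sure you’re on a recent version of pymc. The error occurs because the first positional argument of sample_numpyro_nuts is draws, so dXfit is being passed as the number of draws instead of as the model; that is why numpyro ends up trying to add an int (num_warmup) to a Model (num_samples). If you do call sample_numpyro_nuts directly, pass the model by keyword (model=dXfit) or call it inside the model context.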

1 Like