Hi,
I’m trying to use the NumPyro sampler via `pm.sampling.jax.sample_numpyro_nuts`, passing my PyMC model `dXfit`, but I get the error below.
Are there any recent working examples?
Otherwise, does anyone have insight into what is going wrong?
173 trace =pm.sampling.jax.sample_numpyro_nuts(dXfit, chains=4)
File ~/anaconda3/envs/pymc_env/lib/python3.12/site-packages/pymc/sampling/jax.py:611, in sample_jax_nuts(draws, tune, chains, target_accept, random_seed, initvals, jitter, model, var_names, nuts_kwargs, progressbar, keep_untransformed, chain_method, postprocessing_backend, postprocessing_vectorize, postprocessing_chunks, idata_kwargs, compute_convergence_checks, nuts_sampler)
608 raise ValueError(f"{nuts_sampler=} not recognized")
610 tic1 = datetime.now()
--> 611 raw_mcmc_samples, sample_stats, library = sampler_fn(
612 model=model,
613 target_accept=target_accept,
614 tune=tune,
615 draws=draws,
616 chains=chains,
617 chain_method=chain_method,
618 progressbar=progressbar,
619 random_seed=random_seed,
620 initial_points=initial_points,
621 nuts_kwargs=nuts_kwargs,
622 )
623 tic2 = datetime.now()
625 if idata_kwargs is None:
File ~/anaconda3/envs/pymc_env/lib/python3.12/site-packages/pymc/sampling/jax.py:437, in _sample_numpyro_nuts(model, target_accept, tune, draws, chains, chain_method, progressbar, random_seed, initial_points, nuts_kwargs)
429 nuts_kwargs.setdefault("dense_mass", False)
431 nuts_kernel = NUTS(
432 potential_fn=logp_fn,
433 target_accept_prob=target_accept,
434 **nuts_kwargs,
435 )
--> 437 pmap_numpyro = MCMC(
438 nuts_kernel,
439 num_warmup=tune,
440 num_samples=draws,
441 num_chains=chains,
442 postprocess_fn=None,
443 chain_method=chain_method,
444 progress_bar=progressbar,
445 )
447 map_seed = jax.random.PRNGKey(random_seed)
448 if chains > 1:
File ~/anaconda3/envs/pymc_env/lib/python3.12/site-packages/numpyro/infer/mcmc.py:382, in MCMC.__init__(self, sampler, num_warmup, num_samples, num_chains, thinning, postprocess_fn, chain_method, progress_bar, jit_model_args)
380 self._cache = {}
381 self._collection_params = {}
--> 382 self._set_collection_params()
File ~/anaconda3/envs/pymc_env/lib/python3.12/site-packages/numpyro/infer/mcmc.py:514, in MCMC._set_collection_params(self, lower, upper, collection_size, phase)
509 def _set_collection_params(
510 self, lower=None, upper=None, collection_size=None, phase=None
511 ):
512 self._collection_params["lower"] = self.num_warmup if lower is None else lower
513 self._collection_params["upper"] = (
--> 514 self.num_warmup + self.num_samples if upper is None else upper
515 )
516 self._collection_params["collection_size"] = collection_size
517 self._collection_params["phase"] = phase
TypeError: unsupported operand type(s) for +: 'int' and 'Model'