In the current implementation of `pymc.sampling_jax.sample_numpyro_nuts()`, only some of the NumPyro NUTS sampler's arguments are available to the user, because a few are hard-coded in the function:
```python
if nuts_kwargs is None:
    nuts_kwargs = {}
nuts_kernel = NUTS(
    potential_fn=logp_fn,
    target_accept_prob=target_accept,
    adapt_step_size=True,
    adapt_mass_matrix=True,
    dense_mass=False,
    **nuts_kwargs,
)
```
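One practical consequence of this pattern: because these keywords are passed explicitly and `**nuts_kwargs` is then unpacked into the same call, a user cannot override them through `nuts_kwargs`. Python raises a `TypeError` for the duplicated keyword instead. Here is a minimal sketch of that behavior. It uses `fake_nuts`, a hypothetical stand-in for `numpyro.infer.NUTS` that only echoes its keyword arguments so the example runs without NumPyro or JAX, and `build_kernel_current`, a hypothetical helper that mirrors the snippet above:

```python
from typing import Any, Callable, Dict, Optional


def fake_nuts(**kwargs: Any) -> Dict[str, Any]:
    # Hypothetical stand-in for numpyro.infer.NUTS: it just returns its kwargs.
    return kwargs


def build_kernel_current(
    logp_fn: Callable, target_accept: float, nuts_kwargs: Optional[dict] = None
) -> Dict[str, Any]:
    # Hypothetical helper mirroring the current pattern: fixed keywords
    # followed by **nuts_kwargs in the same call.
    if nuts_kwargs is None:
        nuts_kwargs = {}
    return fake_nuts(
        potential_fn=logp_fn,
        target_accept_prob=target_accept,
        adapt_step_size=True,
        adapt_mass_matrix=True,
        dense_mass=False,
        **nuts_kwargs,
    )


# A keyword that is not hard-coded passes straight through.
print(build_kernel_current(lambda x: 0.0, 0.8, {"max_tree_depth": 12})["max_tree_depth"])  # 12

# Trying to override a hard-coded keyword fails at the call site.
try:
    build_kernel_current(lambda x: 0.0, 0.8, {"dense_mass": True})
except TypeError as err:
    print(err)
```

So options like `max_tree_depth` already work through `nuts_kwargs`, but `adapt_step_size`, `adapt_mass_matrix` and `dense_mass` are effectively locked.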
I think it would be useful to make all of these arguments available through the PyMC interface, but before proposing a change I was curious whether there is a reason they aren't. If not, would this be a PR I could make?
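One possible shape for such a change, sketched under the assumption that the current values should stay as defaults: collect them in a dict and merge `nuts_kwargs` over it, so anything the user supplies takes precedence. `fake_nuts` (a stand-in for `numpyro.infer.NUTS` that just echoes its kwargs) and `build_kernel_proposed` are hypothetical names for illustration, not PyMC or NumPyro API:

```python
from typing import Any, Callable, Dict, Optional


def fake_nuts(**kwargs: Any) -> Dict[str, Any]:
    # Hypothetical stand-in for numpyro.infer.NUTS: it just returns its kwargs.
    return kwargs


def build_kernel_proposed(
    logp_fn: Callable, target_accept: float, nuts_kwargs: Optional[dict] = None
) -> Dict[str, Any]:
    # Hypothetical helper: the current values become defaults, not fixed keywords.
    defaults = {
        "target_accept_prob": target_accept,
        "adapt_step_size": True,
        "adapt_mass_matrix": True,
        "dense_mass": False,
    }
    # Later keys win in a dict merge, so user-supplied values override defaults.
    merged = {**defaults, **(nuts_kwargs or {})}
    return fake_nuts(potential_fn=logp_fn, **merged)


kernel = build_kernel_proposed(lambda x: 0.0, 0.8, {"dense_mass": True})
print(kernel["dense_mass"], kernel["adapt_step_size"])  # True True
```

In this sketch `potential_fn` is still passed separately, since it is derived from the PyMC model rather than chosen by the user, and existing calls that omit `nuts_kwargs` get exactly the same kernel settings as before.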