Open all NUTS kwargs for sampling with Numpyro

In the current implementation of pymc.sampling_jax.sample_numpyro_nuts(), only some of the Numpyro NUTS sampler's arguments are available to the user, because a few are preset inside the function:

```python
if nuts_kwargs is None:
    nuts_kwargs = {}
# Preset sampler options; any extra user-supplied options arrive via nuts_kwargs
nuts_kernel = NUTS(
    potential_fn=logp_fn,
    target_accept_prob=target_accept,
    adapt_step_size=True,
    adapt_mass_matrix=True,
    dense_mass=False,
    **nuts_kwargs,
)
```

I think it would be useful to make all of these arguments available through the PyMC interface, but before proposing a change I was curious whether there is a reason they aren't. If not, would this be a PR I could make?

Those are only being “preset” as defaults. They should all be available to be set explicitly by the user, no?

Unfortunately, no: if you pass an argument in nuts_kwargs that is already preset, you get an error like this one (from trying to pass dense_mass):

 numpyro.infer.hmc.NUTS() got multiple values for keyword argument 'dense_mass'
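The error comes from ordinary Python call semantics, not from anything specific to NumPyro: a keyword supplied both explicitly and inside `**kwargs` raises `TypeError` before the callee even runs. The minimal reproduction below assumes nothing about NumPyro's internals; `fake_nuts` is a hypothetical stand-in that only copies the keyword names used in the call above, so it runs without NumPyro installed.

```python
# Hypothetical stand-in for numpyro.infer.NUTS; only the keyword names matter here.
def fake_nuts(potential_fn=None, target_accept_prob=0.8, adapt_step_size=True,
              adapt_mass_matrix=True, dense_mass=False, max_tree_depth=10):
    return {"dense_mass": dense_mass, "max_tree_depth": max_tree_depth}


# The user tries to override one of the preset options.
nuts_kwargs = {"dense_mass": True}

try:
    # dense_mass is now supplied twice: explicitly and again through **nuts_kwargs
    fake_nuts(
        potential_fn=None,
        target_accept_prob=0.8,
        adapt_step_size=True,
        adapt_mass_matrix=True,
        dense_mass=False,
        **nuts_kwargs,
    )
    error_message = None
except TypeError as err:
    error_message = str(err)

print(error_message)
```

Passing a key that is not preset (for example `{"max_tree_depth": 8}`) goes through without complaint, which is why only the hard-coded options are blocked.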

Ah, I see. So arguments like adapt_step_size, adapt_mass_matrix, and dense_mass can't be passed through sample_numpyro_nuts(). There may be reasons for this that I'm not aware of, but if you open an issue, you can find out one way or another.
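One way to expose these arguments is to treat the preset values as defaults and let user-supplied keys override them, so no keyword is ever passed twice. The sketch below assumes that approach; `DEFAULT_NUTS_KWARGS` and `merge_nuts_kwargs` are hypothetical names, not part of PyMC's API, and this is not necessarily how the linked pull request implements the change.

```python
from typing import Any, Dict, Optional

# Hypothetical defaults mirroring the values preset in sample_numpyro_nuts().
DEFAULT_NUTS_KWARGS: Dict[str, Any] = {
    "adapt_step_size": True,
    "adapt_mass_matrix": True,
    "dense_mass": False,
}


def merge_nuts_kwargs(
    target_accept: float, nuts_kwargs: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
    """Hypothetical helper: combine preset defaults with user overrides.

    A fresh dict is built on every call, so DEFAULT_NUTS_KWARGS is never
    mutated, and update() lets any user-supplied key replace the default.
    """
    merged = {"target_accept_prob": target_accept, **DEFAULT_NUTS_KWARGS}
    merged.update(nuts_kwargs or {})
    return merged


kwargs = merge_nuts_kwargs(0.8, {"dense_mass": True, "max_tree_depth": 8})
print(kwargs)
```

The merged dict could then be passed as `NUTS(potential_fn=logp_fn, **kwargs)`. One design consequence worth noting: with this ordering, a `target_accept_prob` key in `nuts_kwargs` silently wins over the `target_accept` argument, so a real implementation might prefer to warn or raise in that case.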

Great, will do! Thank you

Link to issue: Open all NUTS kwargs to user for sampling with Numpyro · Issue #6020 · pymc-devs/pymc · GitHub

Pull request: Pass user-provided NUTS kwargs to Numpyro by jhrcook · Pull Request #6021 · pymc-devs/pymc · GitHub
