How can I set the maximum tree depth for the NUTS method from the numpyro library?

@ricardoV94 @junpenglao How can I set the maximum tree depth for the NUTS method from the numpyro library?
The way described in the test file test_mcmc_external.py doesn’t work:

import pymc as pm
import numpy as np

with pm.Model():
    a = pm.Normal("a")
    idata = pm.sample(nuts_sampler="numpyro",
                      target_accept=0.99,
                      # not forwarded to numpyro, as the output below shows
                      nuts={"max_treedepth": 1},
                      random_seed=1410)

print(np.max(idata.sample_stats.tree_depth))
# <xarray.DataArray 'tree_depth' ()>
# array(4)

Passing it through a nuts_kwargs argument to pm.sample instead raises ValueError: Unused step method arguments: {'nuts_kwargs'}.

Yeah, I don’t see any code path for passing arguments other than target_accept. I guess you could call the numpyro sampling function directly and pass it as nuts_kwargs? See pymc/jax.py at commit 261862d778910a09c5b61edcc66958519a86815e in pymc-devs/pymc on GitHub.

CC @fonnesbeck

Opened an issue here: pymc-devs/pymc issue #6757, "Numpyro sampler `nuts_kwargs` can't be passed from `sample`".


Yes, passing it directly to pymc.sampling.jax.sample_numpyro_nuts via the nuts_kwargs argument works. Thanks!