@ricardoV94 @junpenglao How can I set the maximum tree depth for the NUTS sampler from the numpyro library?
The approach used in the test file `test_mcmc_external.py` doesn't work:
```python
import numpy as np
import pymc as pm

with pm.Model():
    a = pm.Normal("a")
    idata = pm.sample(
        nuts_sampler="numpyro",
        target_accept=0.99,
        nuts={"max_treedepth": 1},
        random_seed=1410,
    )

print(np.max(idata.sample_stats.tree_depth))
# <xarray.DataArray 'tree_depth' ()>
# array(4)
```
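The sampler still reaches a tree depth of 4, so the `max_treedepth` setting appears to be ignored. Passing the settings via a `nuts_kwargs` argument instead raises `ValueError: Unused step method arguments: {'nuts_kwargs'}`.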