Similar to the thread "How can I set the maximum tree depth for the NUTS method from the numpyro library?", I'm wondering if there is a way to set the maximum tree depth for the NUTS method from the blackjax library. Judging from `blackjax.mcmc.nuts`, I think the corresponding keyword argument is `max_num_doublings`. But passing it directly to `pymc.sampling.jax.sample_blackjax_nuts` doesn't work:
```python
import numpy as np
import pymc as pm
import pymc.sampling.jax  # the JAX sampler wrappers live in a submodule that `import pymc` may not load

m = pm.Model()
with m:
    a = pm.Normal("a")

# Attempt 1: pm.sample with the blackjax backend, trying to cap the tree depth via `nuts`
idata = pm.sample(
    nuts_sampler="blackjax",
    target_accept=0.99,
    nuts={"max_treedepth": 1},
    random_seed=1410,
)
print(np.max(idata.sample_stats.tree_depth))
# <xarray.DataArray 'tree_depth' ()>
# array(4)

# Attempt 2: call the blackjax wrapper directly, passing blackjax's own keyword
idata = pm.sampling.jax.sample_blackjax_nuts(
    draws=5000,
    tune=5000,
    chains=8,
    target_accept=0.99,
    model=m,
    nuts_kwargs=dict(max_num_doublings=1),
)
print(np.max(idata.sample_stats.tree_depth))
# <xarray.DataArray 'tree_depth' ()>
# array(5, dtype=int64)
```
@ricardoV94 Do you see any code path to do that?