How can I set the maximum tree depth for the NUTS method from the blackjax library?

Similar to the thread "How can I set the maximum tree depth for the NUTS method from the numpyro library?", I'm wondering whether there is a way to set the maximum tree depth for the NUTS method from the blackjax library. Judging from `blackjax.mcmc.nuts`, I think the corresponding keyword argument is `max_num_doublings`. However, neither passing `max_treedepth` through `pm.sample` nor passing `max_num_doublings` directly to `pymc.sampling.jax.sample_blackjax_nuts` seems to have any effect: the sampled tree depth still goes above 1.

```python
import pymc as pm
import numpy as np

m = pm.Model()
with m:
    a = pm.Normal("a")
    # Attempt 1: PyMC-style max_treedepth passed through pm.sample
    idata = pm.sample(nuts_sampler="blackjax",
                      target_accept=0.99,
                      nuts={"max_treedepth": 1},
                      random_seed=1410)

print(np.max(idata.sample_stats.tree_depth))
# <xarray.DataArray 'tree_depth' ()>
# array(4)

# Attempt 2: blackjax's max_num_doublings passed to sample_blackjax_nuts
idata = pm.sampling.jax.sample_blackjax_nuts(
    draws=5000,
    tune=5000,
    chains=8,
    target_accept=0.99,
    model=m,
    nuts_kwargs=dict(max_num_doublings=1),
)

print(np.max(idata.sample_stats.tree_depth))
# <xarray.DataArray 'tree_depth' ()>
# array(5, dtype=int64)
```

@ricardoV94 Do you see any code path in PyMC that would let me set this?