How can I set the maximum tree depth for the NUTS sampler from the numpyro library?

Yes, passing max_tree_depth directly to pymc.sampling.jax.sample_numpyro_nuts through the nuts_kwargs argument works. Thanks!