Yes, passing it directly to pymc.sampling.jax.sample_numpyro_nuts via the nuts_kwargs argument works. Thanks!
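For anyone landing here later, this is roughly what that looks like. It is only a sketch: the toy model, the helper name `sample_with_nuts_kwargs`, and the particular options (`max_tree_depth`, `dense_mass`) are illustrative assumptions, not the setting discussed above. As far as I know, `nuts_kwargs` is forwarded to NumPyro's `NUTS` kernel, so the keys should be arguments that kernel accepts.

```python
# Illustrative options only; both are keyword arguments of numpyro.infer.NUTS.
nuts_kwargs = {"max_tree_depth": 8, "dense_mass": True}


def sample_with_nuts_kwargs(model, nuts_kwargs, **sampler_kwargs):
    """Hypothetical helper: run the NumPyro NUTS sampler with extra kernel options."""
    # Imported lazily so this file can be loaded without the JAX/NumPyro extras.
    from pymc.sampling.jax import sample_numpyro_nuts

    return sample_numpyro_nuts(model=model, nuts_kwargs=nuts_kwargs, **sampler_kwargs)


def run_example():
    """Build a toy model (an assumption for illustration) and sample it."""
    import pymc as pm

    with pm.Model() as model:
        mu = pm.Normal("mu", 0.0, 1.0)
        pm.Normal("y", mu, 1.0, observed=[0.1, -0.3, 0.5])

    return sample_with_nuts_kwargs(model, nuts_kwargs, draws=200, tune=200, chains=2)
```

Calling `run_example()` needs `pymc`, `jax`, and `numpyro` installed and returns the sampler's inference data.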