Max_steps with sampling_jax.sample_numpyro_nuts

It should work for all parameters, but we have only tested those explicitly.