Max_steps with sampling_jax.sample_numpyro_nuts


my_fit_pymc = pm.sampling_jax.sample_numpyro_nuts(
    random_seed=1234,
    tune=250,
    draws=250,
    target_accept=0.90,
    chains=2,
    chain_method='sequential',
    idata_kwargs=dict(log_likelihood=False),
    nuts_kwargs=dict(max_tree_depth=12),
)

Setting max_tree_depth through nuts_kwargs (12 here, up from NumPyro's default of 10) allowed me to get more steps per iteration.
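
For anyone wondering how the tree depth maps to a step limit, here is a rough sketch of the arithmetic. It assumes the usual NUTS doubling scheme, where a trajectory that reaches depth d has taken at most 2^d - 1 leapfrog steps. The helper name `max_leapfrog_steps` is made up for illustration and is not part of PyMC or NumPyro.

```python
def max_leapfrog_steps(max_tree_depth: int) -> int:
    """Upper bound on leapfrog steps in one NUTS iteration.

    Assumes the standard doubling scheme: doubling j adds 2**j steps,
    so depth d allows at most 2**0 + ... + 2**(d-1) = 2**d - 1 steps.
    Hypothetical helper, not a PyMC or NumPyro function.
    """
    if max_tree_depth < 0:
        raise ValueError("max_tree_depth must be non-negative")
    return 2 ** max_tree_depth - 1


# Compare the default depth (10) with the value used in the call above (12).
for depth in (10, 12):
    print(f"max_tree_depth={depth}: up to {max_leapfrog_steps(depth)} steps")
```

So each extra level of depth roughly doubles the cap, and it also roughly doubles the worst-case cost of an iteration that hits the limit.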