Max_steps with sampling_jax.sample_numpyro_nuts

I have a hierarchical regression model with two levels, countries and local areas. The country level has 84 different ids, and there are 156,000 local areas. The full dataset has about 2,500,000 rows. When I fit the model to a subset of the data (all rows whose local area id is in a random sample of 12,000 of the 156,000 local areas), NUTS converges pretty quickly. But as I increase the number of local areas included in the subset, NUTS seems to get stuck with very small step sizes (around 10^-6). My interpretation is that when the number of parameters gets significantly larger than the default maximum number of steps (1024), NUTS has trouble finding a configuration of parameter values that is better than the current state, so it keeps decreasing the step size. NUTS seems to work as long as the number of parameters is below 10 times the maximum number of steps. Does my interpretation of what happens sound reasonable?

If it is, I think the solution should be to increase the maximum number of steps per iteration. For the full set with 156,000 parameters, the maximum number of steps per iteration should perhaps be something like 15,000 instead of the default 1024. I have tried setting both max_steps=2048 and max_tree_size=11, but these seem to be ignored, since the progress bar never shows more than 1023 steps:

15/500 [02:09<1:33:53, 11.62s/it, 1023 steps of size 6.63e-06. acc. prob=0.59]
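For reference, the 1023 in the progress bar matches how NumPyro's NUTS caps its trajectories: the trajectory doubles in length at each tree level, so a tree of depth d holds at most 2**d - 1 leapfrog steps, and NumPyro's default max_tree_depth is 10. The plain-Python sketch below shows how that cap grows with depth and which depth would cover the roughly 15,000 steps proposed above. The helper names are hypothetical and are not part of PyMC or NumPyro.

```python
def max_leapfrog_steps(max_tree_depth: int) -> int:
    """Upper bound on leapfrog steps in a single NUTS iteration.

    The trajectory doubles at each tree level, so a tree of depth d
    contains at most 2**d - 1 leapfrog steps.
    """
    return 2 ** max_tree_depth - 1


def depth_for_steps(target_steps: int) -> int:
    """Smallest tree depth d with 2**d - 1 >= target_steps.

    2**d - 1 >= t is equivalent to 2**d > t, and the smallest such d is
    exactly the bit length of t (integer arithmetic, no float rounding).
    """
    return target_steps.bit_length()


for depth in (10, 11, 12):
    print(depth, max_leapfrog_steps(depth))  # 10 1023, then 11 2047, then 12 4095

needed = depth_for_steps(15_000)
print(needed, max_leapfrog_steps(needed))  # prints: 14 16383
```

Each extra level of depth roughly doubles the worst-case cost of an iteration, so a depth of 14 allows up to about 16 times as many gradient evaluations per draw as the default.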

Here is the function call I use for sampling:

my_fit_pymc = pm.sampling_jax.sample_numpyro_nuts(random_seed=1234, tune=250, draws=250, target_accept=0.90, chains = 2, chain_method='sequential', idata_kwargs=dict(log_likelihood=False, max_tree_size=11, max_steps=2048))

I have no idea about the theoretical questions.

Those parameters should not be part of idata_kwargs; they should be passed directly to the sample function. However, this type of argument was not being forwarded until very recently, so if it's not working yet in the latest release, you'll have to wait for the next one.

Thanks for your reply! I guess you are referring to pull request #6021 in pymc-devs/pymc on GitHub, "Pass user-provided NUTS kwargs to Numpyro" by jhrcook.

I looked it over, and as I understood it, it only affects adapt_step_size, adapt_mass_matrix and dense_mass, but I'll try it out.


my_fit_pymc = pm.sampling_jax.sample_numpyro_nuts(random_seed=1234, tune=250, draws=250, target_accept=0.90, chains = 2, chain_method='sequential', idata_kwargs=dict(log_likelihood=False), nuts_kwargs=dict(max_tree_depth=12))

This call, with max_tree_depth passed through nuts_kwargs, allowed me to get more steps per iteration.

It should work for all parameters, but we only tested those explicitly.

More steps are not guaranteed to help, though. In your situation, most likely one or more parameters are stuck in a region of difficult geometry. You should look at modifying the priors and inspecting the posterior before increasing the number of steps.
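Before raising the cap, one cheap check is how often the sampler actually hits it. The sketch below is plain Python and assumes you have already pulled the per-draw leapfrog step counts out of the fit into a sequence (where they are stored in the returned InferenceData depends on the PyMC version, so that part is left out). The function name and the example counts are hypothetical. If nearly every draw saturates while the step size stays tiny, that is consistent with difficult geometry, and more steps alone may only make each iteration slower.

```python
from typing import Sequence


def saturation_fraction(n_steps: Sequence[int], max_tree_depth: int = 10) -> float:
    """Fraction of draws that used the maximum number of leapfrog steps.

    With a tree-depth limit d, NUTS can take at most 2**d - 1 leapfrog steps
    per iteration; a draw at that count has hit the limit ("saturated").
    """
    if len(n_steps) == 0:
        raise ValueError("n_steps must not be empty")
    cap = 2 ** max_tree_depth - 1
    saturated = sum(1 for steps in n_steps if steps >= cap)
    return saturated / len(n_steps)


# Hypothetical step counts for five draws, not output from the model above.
example_steps = [1023, 1023, 511, 1023, 255]
print(saturation_fraction(example_steps))  # prints 0.6
```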
