I have a hierarchical regression model with two levels, countries and local areas. The country level has 84 different ids and there are 156,000 local areas. The full dataset has about 2,500,000 rows. When I fit the model to a subset of the data (all rows whose local area id falls in a random sample of 12,000 of the 156,000 local areas), NUTS converges pretty quickly, but as I increase the number of local areas included in the subset, NUTS seems to get stuck at very small step sizes (around 10^-6). My interpretation is that when the number of parameters gets significantly larger than the default maximum number of steps per iteration (1024), NUTS has trouble finding a configuration of parameter values that is better than the current state, so it keeps shrinking the step size. It seems NUTS works as long as the number of parameters stays below 10 times the maximum number of steps. Does my interpretation of what happens sound reasonable?
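For clarity, this is roughly how I build the subset (the dataframe and column names here are illustrative, not my actual ones):

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(1234)

# df is the full ~2,500,000-row dataframe.
# Sample 12,000 of the 156,000 local area ids and keep only their rows.
sampled_areas = rng.choice(df["local_area_id"].unique(), size=12_000, replace=False)
subset = df[df["local_area_id"].isin(sampled_areas)]
```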
If it is, I think the solution should be to increase the maximum number of steps per iteration. For the full set with 156,000 parameters, the maximum number of steps per iteration should perhaps be something like 15,000 instead of the default 1024. I have tried setting both `max_steps=2048` and `max_tree_size=11`, but these seem to be ignored, since the progress output never shows more than 1023 steps per iteration:
```
15/500 [02:09<1:33:53, 11.62s/it, 1023 steps of size 6.63e-06. acc. prob=0.59]
```
Here is the function call I use for the sampling:

```python
my_fit_pymc = pm.sampling_jax.sample_numpyro_nuts(
    random_seed=1234,
    tune=250,
    draws=250,
    target_accept=0.90,
    chains=2,
    chain_method="sequential",
    idata_kwargs=dict(log_likelihood=False, max_tree_size=11, max_steps=2048),
)
```
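Reading the NumPyro docs, I wonder if the right knob is actually `max_tree_depth` on the NUTS kernel: its default of 10 allows at most 2^10 - 1 = 1023 leapfrog steps, which matches exactly what the progress bar shows. Below is a sketch of what I would try next; note that `nuts_kwargs` is my guess at how `sample_numpyro_nuts` forwards arguments to the kernel in my PyMC version, so the exact keyword may differ:

```python
# Sketch of what I would try next. NumPyro's NUTS kernel takes
# max_tree_depth (default 10, i.e. at most 2**10 - 1 = 1023 steps);
# depth 14 would allow 2**14 - 1 = 16383 steps, in the ballpark of
# the ~15,000 I estimated above.
# NOTE: nuts_kwargs is my assumption about how sample_numpyro_nuts
# passes arguments to numpyro.infer.NUTS -- it may be named
# differently (or not exist) in other PyMC versions.
my_fit_pymc = pm.sampling_jax.sample_numpyro_nuts(
    random_seed=1234,
    tune=250,
    draws=250,
    target_accept=0.90,
    chains=2,
    chain_method="sequential",
    nuts_kwargs=dict(max_tree_depth=14),
    idata_kwargs=dict(log_likelihood=False),
)
```

Is that the correct way to raise the cap, or is there another recommended approach for models this size?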