Hi PyMC3 community,
I’m using the NUTS sampler on a large model, and it samples fine up until exactly 200 samples, and then it slows down (from 1 it/s to 40 s/it). I’m not familiar with how NUTS works under the hood, but does NUTS change its step size at 200 iterations?
That 200 mark is consistent across multiple variations of the model, so I think this has something to do with the NUTS algorithm and not something idiosyncratic to my model. If so, are there any NUTS kwargs that could help?
Ah, I should have just read up on NUTS more. NUTS has two tree-depth parameters: `early_max_treedepth`, which applies during the first 200 tuning iterations, and `max_treedepth`, which applies after that. I had `max_treedepth` set quite high (14) because I was getting warnings. After 200 iterations, this higher depth kicked in and slowed sampling.
I will try to fix the max tree depth warnings in other ways and keep these parameters at their defaults.
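A minimal sketch of the switch described above, assuming (from my reading of PyMC3's NUTS step) that the early cap is used only while tuning and before iteration 200, with defaults of `max_treedepth=10` and `early_max_treedepth=8`. The function `treedepth_for_iteration` is hypothetical and only illustrates the rule; it is not part of PyMC3.

```python
def treedepth_for_iteration(iteration, tuning, max_treedepth=10, early_max_treedepth=8):
    """Return the tree-depth cap NUTS would apply at a given iteration.

    Hypothetical helper: while tuning and before iteration 200 the smaller
    early_max_treedepth is used; otherwise max_treedepth is used.
    """
    if tuning and iteration < 200:
        return early_max_treedepth
    return max_treedepth


# max_treedepth raised to 14, early_max_treedepth left at its default of 8:
for it in (0, 199, 200, 500):
    print(it, treedepth_for_iteration(it, tuning=True, max_treedepth=14))
# 0 8
# 199 8
# 200 14
# 500 14
```

This is why the slowdown appears at exactly iteration 200 regardless of model variation: that is where the cap jumps from 8 to whatever `max_treedepth` was set to.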
A `max_treedepth` of 14 allows up to 2^14 = 16384 leapfrog steps per iteration, which is likely way overkill and pretty impractical. If you are hitting max treedepth warnings, you should inspect your model parameterization: likely the parameters are poorly scaled, or your posterior distribution has narrow regions that are geometrically hard to explore.
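To put the numbers in perspective, here is a quick sketch treating `2 ** depth` as the approximate leapfrog-step cap per iteration (matching the 16384 figure above). Each leapfrog step needs one gradient evaluation, so the cap bounds the per-iteration cost. `max_leapfrog_steps` is just an illustrative helper, not a PyMC3 function.

```python
def max_leapfrog_steps(treedepth):
    # Each tree doubling adds as many steps as the tree already has,
    # so the cap grows roughly as 2**treedepth.
    return 2 ** treedepth


for depth in (8, 10, 14):
    print(depth, max_leapfrog_steps(depth))
# 8 256
# 10 1024
# 14 16384

# Going from the early cap (8) to 14 raises the worst case by 2**6:
print(max_leapfrog_steps(14) // max_leapfrog_steps(8))
# 64
```

If many iterations run into the cap, a worst-case factor of 64 is roughly in line with the jump from 1 it/s to 40 s/it described in the question.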