I am running on Google Colab, which has one CPU (as well as a GPU), and thus runs jobs sequentially.
During tuning, the first chain runs its first 200 samples quickly. But the second chain is very slow from the start, which often means it takes far longer overall.
Some examples from two similar but different models running right now:
100%|██████████| 2000/2000 [2:43:47<00:00, 2.75s/it]
5%|▍ | 95/2000 [2:38:38<128:51:02, 243.50s/it]
100%|██████████| 2000/2000 [56:40<00:00, 1.28it/s]
5%|▌ | 105/2000 [45:24<25:20:38, 48.15s/it]
We can see in both of the above examples (run with tune=1000, draws=1000) that the second chain has already taken about as long as the first chain took to finish. From the parameter early_max_treedepth, I understand the first 200 samples are “early tuning,” but why is this not done for the second chain as well?
Besides the efficiency considerations, I wonder if it suggests the second chain is not independent of the first.
Likely there is some problem (or sub-optimal parameterization) with your model that makes it difficult to sample: https://statmodeling.stat.columbia.edu/2008/05/13/the_folk_theore/
I will put the model and data on GitHub soon, but the issue here is not the overall slowness (it takes hours); perhaps if it took seconds I wouldn’t mind.
Rather, it is the fact that the first chain seems to manage better because its first 200 iterations are generated quickly, giving the chain a head start, so to speak. In the chains above, the rate for the first 200 iterations was about 1 it/s. This does not happen for the second chain, whose iterations take orders of magnitude longer. Nor does the head start hurt the model in any way; on the contrary, it seems clear that generating those iterations quickly helped the first chain converge faster.
But why shouldn’t it happen for the second chain as well?
I feel this is a bug, or at least something that should be configurable.
As a hack you could generate the chains independently:
traces = [pm.sample(chains=1) for _ in range(n_chains)]
But would I then still have the diagnostics to compare the chains for convergence?
You should be able to just wrap them in a MultiTrace object (you may have to change the chain id, x.chain, to disambiguate).
Usually, this is related to the start value of the chain. I don’t know your model, so I can’t say exactly why, but it might be that your model has some local maxima, and some of them have a very different geometry than the others. Say your second chain started near a local maximum whose geometry is very different from the rest (i.e., high curvature): NUTS would adapt to the local curvature, resulting in a small step size and a large number of leapfrog steps.
One thing you can try is to set init in pm.sample(..., init='adapt_diag') so that all chains start from the same point, or supply the starting values by hand.
I am looking at the code and I wonder if it might be due to that early_max_treedepth. I used a tree depth of 25, so the difference between early_max_treedepth (8) and 25 is even greater than with the default value. The NUTS sampler uses step.iter_count to determine whether early_max_treedepth or the normal max_treedepth should be used. In parallel sampling, the NUTS step is pickled, so in practice there is a separate step object for each chain. But in sequential sampling, the step object is initialized once and shared between the chains. This is “hacked” with respect to step.tune so that the following chains start tuning again, but iter_count, and any other state maintained by the step object, appears to still depend on the previous chains. As a result, given, say, tune=500, draws=500, the first chain starts at iter_count = 0, the second at iter_count = 1000, the third at iter_count = 2000, and so on. Each proceeds from a different count, but only the first uses early_max_treedepth, because only its iter_count < 200. (Also, why is this magic number 200 not configurable as well?) The same would apply to any other state maintained by the step method.
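The mechanism described above can be illustrated with a toy sketch in plain Python (this is a hypothetical ToyNUTS class mirroring the description, not the actual pymc3 implementation; the cutoff 200 is the hard-coded early-tuning window mentioned above):

```python
EARLY_WINDOW = 200  # the hard-coded "early tuning" cutoff discussed above

class ToyNUTS:
    """Hypothetical stand-in for a NUTS step object with shared state."""

    def __init__(self, max_treedepth=25, early_max_treedepth=8):
        self.max_treedepth = max_treedepth
        self.early_max_treedepth = early_max_treedepth
        self.iter_count = 0  # never reset between sequential chains

    def current_max_treedepth(self):
        # Mirrors the iter_count check described above.
        if self.iter_count < EARLY_WINDOW:
            return self.early_max_treedepth
        return self.max_treedepth

    def step(self):
        depth = self.current_max_treedepth()
        self.iter_count += 1
        return depth

# Sequential sampling re-uses one step object across all chains.
step = ToyNUTS()
tune, draws = 500, 500
for chain in range(3):
    first_depth = step.current_max_treedepth()
    for _ in range(tune + draws):
        step.step()
    print(f"chain {chain}: started with max treedepth {first_depth}")
# → chain 0 starts with treedepth 8; chains 1 and 2 start with 25,
#   because iter_count is already past 200 when they begin.
```

Only the first chain ever sees iter_count < 200, so only it gets the cheap early-tuning treedepth; a per-chain step object (as in parallel sampling) would give every chain the head start.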
Hmm, if that’s the case it would likely be a bug, but I think the step method is reinitialized for each chain even in sequential sampling.