Bug in warning treedepth?

I am picking up pymc again after a long hiatus and am using pymc v5.1.1. The following warning function, which can be found in stats/convergence.py, seems to generate a lot of warnings for a simple test model.

At this point I don’t really understand the concept of treedepth, so I have no clue what it is supposed to check. Hopefully my questions are nevertheless relevant.

def warn_treedepth(idata: arviz.InferenceData) -> List[SamplerWarning]:
    ...
    for c in treedepth.chain:
        if sum(treedepth.sel(chain=c)) / treedepth.sizes["draw"] > 0.05:
            warnings.append(...)
    return warnings
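The per-chain loop above calls Python's built-in sum on an xarray DataArray, which iterates over the draws one 0-d DataArray at a time. Below is a minimal sketch comparing that with a vectorized reduction. It assumes treedepth is a DataArray with dims ("chain", "draw") like the sample-stats variable in the snippet. The random integer depths are made up purely for illustration.

```python
import numpy as np
import xarray as xr

# Synthetic stand-in for the per-draw tree depths (assumption: integer values
# with dims ("chain", "draw"), like `treedepth` in the snippet above).
rng = np.random.default_rng(0)
treedepth = xr.DataArray(
    rng.integers(1, 6, size=(4, 500)),
    dims=("chain", "draw"),
    coords={"chain": np.arange(4), "draw": np.arange(500)},
)

for c in treedepth.chain:
    chain = treedepth.sel(chain=c)
    # Built-in sum(): adds 0-d DataArrays one by one in Python (slow).
    slow = sum(chain) / treedepth.sizes["draw"]
    # Vectorized reduction: one NumPy call under the hood (fast).
    fast = float(chain.mean())
    assert np.isclose(float(slow), fast)

# The same per-chain quantity for all chains at once, without a Python loop.
per_chain = treedepth.mean(dim="draw")
```

Both expressions compute the same per-chain mean, so swapping in .mean() changes only the speed, not the result of the comparison against 0.05.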

My questions:

  1. It seems like sum(...) / ... is a little slow to execute, probably because it uses Python's built-in sum instead of the vectorized .sum() method. I tried replacing the fraction with float(treedepth.sel(chain=c).mean()), and it gives the same number in my case.
  2. The message it generates says that the maximum treedepth has been reached, but the actual check verifies that the MEAN treedepth is greater than 0.05. So which one is correct?
  3. Could it be that it is supposed to be > 20 or so? It seems to me that the mean treedepth will always be at least 1, so warning when it is > 0.05 makes little sense.
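To make questions 2 and 3 concrete, here is a small sketch contrasting the two readings of the check on synthetic data. It assumes, as the post observes, that every draw has a tree depth of at least 1. MAX_TREEDEPTH = 10 is an illustrative limit chosen for this example, and "hit the limit" is modeled as depth >= MAX_TREEDEPTH; neither is taken from the pymc source.

```python
import numpy as np
import xarray as xr

MAX_TREEDEPTH = 10  # assumption: illustrative depth limit for this example

rng = np.random.default_rng(1)
# Hypothetical healthy run: depths between 2 and 5, never reaching the limit.
tree_depth = xr.DataArray(
    rng.integers(2, 6, size=(2, 500)),
    dims=("chain", "draw"),
    coords={"chain": [0, 1], "draw": np.arange(500)},
)

# Reading 1 (the v5.1.1 condition as described): mean depth > 0.05.
# With every depth >= 1 the mean is >= 1, so this fires for every chain.
mean_depth = tree_depth.mean(dim="draw")
fires_on_mean = bool((mean_depth > 0.05).all())

# Reading 2 (what the warning message describes): share of draws at the limit.
hit_limit = tree_depth >= MAX_TREEDEPTH  # boolean flag per draw
frac_at_limit = hit_limit.mean(dim="draw")
fires_on_fraction = bool((frac_at_limit > 0.05).any())

print(fires_on_mean, fires_on_fraction)  # → True False
```

A 5% threshold only makes sense for a fraction in [0, 1], which fits the second reading.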

Seems like the same issues mentioned here.

CC @michaelosthege @aseyboldt

Thanks, cluhmann. I have no idea why it didn’t turn up in my search.

Meanwhile, I found that this was last changed in the following PR: Collect sampler warnings only through stats by michaelosthege · Pull Request #6192 · pymc-devs/pymc · GitHub. The last version before that was v4.2.2, so I installed it, and it does not generate the warnings.

In that version, the same warning seems to be located in pymc/step_methods/hmc/nuts.py, and I added a print to make sure the code actually executes in my case. There the check is n_treedepth / float(n_samples), which reads like “the fraction of samples in which the maximum tree size was reached”: if this fraction is more than 5%, a warning is issued (it is 0% in my case). That sounds quite different from what v5.1.1 is doing.
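Because the v4.2.2 check lives in the NUTS step method itself, one way to picture n_treedepth / float(n_samples) is as a running tally kept per chain while sampling. The sketch below is hypothetical: TreedepthCounter, record and should_warn are illustration names, not pymc API. It assumes the numerator counts draws where the maximum tree depth was reached, as the post reads it, and excluding tuning draws from both counts is also an assumption of the sketch.

```python
from dataclasses import dataclass


@dataclass
class TreedepthCounter:
    """Hypothetical per-chain tally mirroring n_treedepth / float(n_samples)."""

    n_samples: int = 0    # post-tuning draws recorded so far
    n_treedepth: int = 0  # of those, draws that hit the maximum tree depth

    def record(self, reached_max_treedepth: bool, tuning: bool) -> None:
        # Assumption: tuning draws do not count toward the fraction.
        if tuning:
            return
        self.n_samples += 1
        if reached_max_treedepth:
            self.n_treedepth += 1

    def should_warn(self, threshold: float = 0.05) -> bool:
        # Warn only if more than `threshold` of the draws hit the limit.
        return self.n_samples > 0 and self.n_treedepth / float(self.n_samples) > threshold


# Usage: 500 tuning steps, then 1000 draws of which every 50th hits the limit.
counter = TreedepthCounter()
for step in range(1500):
    counter.record(reached_max_treedepth=(step % 50 == 0), tuning=step < 500)
```

Here 20 of the 1000 post-tuning draws hit the limit, a fraction of 2%, so should_warn() stays False. Under this reading, a run where no draw reaches the limit can never trigger the warning, which matches the 0% observed with v4.2.2.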

I think this answers my third question in a way then.

Should be fixed in Release v5.1.2 · pymc-devs/pymc · GitHub
