I am picking up PyMC again after a long hiatus, using PyMC v5.1.1. The following warning function, which can be found in stats/convergence.py, seems to generate a lot of warnings for a simple test model.
At this point I don’t really understand the concept of treedepth so I don’t have a clue about what it is supposed to check. Hopefully my questions are nevertheless relevant.

    def warn_treedepth(idata: arviz.InferenceData) -> List[SamplerWarning]:
        ...
        for c in treedepth.chain:
            if sum(treedepth.sel(chain=c)) / treedepth.sizes["draw"] > 0.05:
                warnings.append(...)
        return warnings

My questions:
- It seems like `sum(...) / ...` is a little slow to execute, probably because it uses the pure-Python `sum` instead of `.sum`. I tried replacing the fraction with `float(treedepth.sel(chain=c).mean())` and it gives the same number in my case.
- The message it generates tells you that the maximum treedepth has been reached, but the actual check verifies that the MEAN treedepth is greater than 0.05. So which one is supposed to be correct?
- Could it be that it is supposed to be > 20 or so? It seems to me that the mean treedepth will always be > 1, so warning when it is > 0.05 makes little sense to me.
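
To make the first question concrete, here is a minimal sketch of the equivalence I am relying on. It uses a plain NumPy array as a stand-in for `treedepth.sel(chain=c)`. The array `flags`, its length, and the assumption that the values are boolean per-draw flags are all mine for illustration. None of it is taken from the elided parts of the function.

```python
import numpy as np

# Hypothetical stand-in for one chain of `treedepth`: 1000 draws.
# ASSUMPTION: the values are boolean flags, one per draw.
flags = np.zeros(1000, dtype=bool)
flags[:80] = True  # pretend 80 of the 1000 draws were flagged

n_draws = flags.size

# Form used in the check: Python's built-in sum walks the array
# one element at a time.
frac_builtin = sum(flags) / n_draws

# Vectorized alternatives: the array's own .sum() and .mean().
frac_method = flags.sum() / n_draws
frac_mean = float(flags.mean())

# All three expressions compute the same quantity.
assert np.isclose(frac_builtin, frac_method)
assert np.isclose(frac_builtin, frac_mean)

# The comparison from the warning function, applied to the mean.
exceeds = frac_mean > 0.05
print(frac_mean, exceeds)
```

If the values really are per-draw flags, the mean would be the fraction of flagged draws and 0.05 would read as 5% of draws. I have not confirmed that this is what the elided code computes.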