Sampling prints "reached the maximum tree depth"

Hi again,

Testing with the Howell1 data, I’m getting the following message:

Chain <xarray.DataArray 'chain' ()>
array(0)
Coordinates:
    chain    int64 0 reached the maximum tree depth. Increase `max_treedepth`, increase `target_accept` or reparameterize.

First off, is the above a warning? If so, it would be nice if it were raised as an actual warning rather than printed as a plain message.
Secondly, as a layman in PyMC and Bayesian stats, what action should I take here? I can’t see where I should pass `max_treedepth`, nor what its default value is. I don’t think I should reparameterize my model; it’s pretty simple, isn’t it?

I’ve found several threads mentioning it, but I suspect the message has become much more frequent recently, including in models where it didn’t use to appear. For instance, this example from 2021 doesn’t seem to print it, but if I copy the code and run it now, the message does appear.

My code example, raw data here.

import pymc as pm
import arviz as az

import pandas as pd
df = pd.read_csv("Howell1.csv", sep=";")
df2 = df[df["age"] >= 18]

print(df2.head())

with pm.Model() as m:
    height_data = pm.Data("height_data", df2["height"] - df2["height"].mean(), mutable=True)

    a = pm.Normal(name="a", mu=60, sigma=10)
    b = pm.LogNormal(name="b", mu=0, sigma=1)
    
    mean = a + b * height_data
    sigma = pm.Uniform("sigma", lower=0, upper=10)

    weight_data = pm.Data("weight_data", df2["weight"], mutable=True)
    W = pm.Normal(
        name="weight",
        mu=mean,
        sigma=sigma,
        observed=weight_data,
    )

    trace = pm.sample()

Full output:

    height     weight   age  male
0  151.765  47.825606  63.0     1
1  139.700  36.485807  63.0     0
2  136.525  31.864838  65.0     0
3  156.845  53.041914  41.0     1
4  145.415  41.276872  51.0     0
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [a, b, sigma]
Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 1 seconds.
Chain <xarray.DataArray 'chain' ()>
array(0)
Coordinates:
    chain    int64 0 reached the maximum tree depth. Increase `max_treedepth`, increase `target_accept` or reparameterize.
Chain <xarray.DataArray 'chain' ()>
array(1)
Coordinates:
    chain    int64 1 reached the maximum tree depth. Increase `max_treedepth`, increase `target_accept` or reparameterize.
Chain <xarray.DataArray 'chain' ()>
array(2)
Coordinates:
    chain    int64 2 reached the maximum tree depth. Increase `max_treedepth`, increase `target_accept` or reparameterize.
Chain <xarray.DataArray 'chain' ()>
array(3)
Coordinates:
    chain    int64 3 reached the maximum tree depth. Increase `max_treedepth`, increase `target_accept` or reparameterize.

Looking at GitHub, I see the message is a warning. That isn’t clear from the output above; in particular, it’s missing the word “warning” 😊

It is a warning, but I would not ignore it. That being said, your code runs without incident for me. What version of pymc are you using and what platform are you on?
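If you want to see how often the limit was actually hit, you can inspect the per-draw sampler statistics directly instead of relying on the printed message. The sketch below assumes the `reached_max_treedepth` entry that PyMC stores in `idata.sample_stats`, with shape `(chain, draw)`. The helper name `fraction_at_max_treedepth` is hypothetical, and the toy array stands in for real sampler output.

```python
import numpy as np


def fraction_at_max_treedepth(reached) -> np.ndarray:
    """Per-chain fraction of draws flagged as hitting the tree-depth limit.

    `reached` is a boolean array shaped (chain, draw), e.g. (assumed)
    idata.sample_stats["reached_max_treedepth"].values after pm.sample().
    """
    reached = np.asarray(reached, dtype=bool)
    if reached.ndim != 2:
        raise ValueError("expected an array shaped (chain, draw)")
    # Mean of a boolean array along the draw axis = fraction of True values.
    return reached.mean(axis=1)


# Toy input standing in for real sampler output: 2 chains x 4 draws.
example = np.array([
    [False, False, True, False],
    [False, False, False, False],
])
print(fraction_at_max_treedepth(example))
```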

I’m on mobile right now, but this is the latest pip install on macOS (M1 chip), in a clean Python 3.11 venv.

I prefer to stick with pip, but can try with conda/mamba if you think that might help.

My inclination is to suggest using the “official” install method (i.e., conda/mamba), but I’m not sure how the installation could give rise to this particular issue. Have you tried other models, perhaps even simpler ones, to see whether anything samples cleanly with your installation?

I installed using conda on Windows, and I get the same warnings for every model I’ve tried since switching to PyMC 4.0, including very simple toy models copied from learning resources. When I run the same models in PyMC3, I don’t get the warning.

Can either of you run this sort of model and see if it works?

import pymc as pm
with pm.Model() as model:
    a = pm.Normal("a")
    b = pm.Normal("b", mu=a, sigma=1, observed=[1,2,3,4,5])
    idata = pm.sample()

Yep, same behaviour:

Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [a]
Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 0 seconds.
Chain <xarray.DataArray 'chain' ()>
array(0)
Coordinates:
    chain    int64 0 reached the maximum tree depth. Increase `max_treedepth`, increase `target_accept` or reparameterize.
Chain <xarray.DataArray 'chain' ()>
array(1)
Coordinates:
    chain    int64 1 reached the maximum tree depth. Increase `max_treedepth`, increase `target_accept` or reparameterize.
Chain <xarray.DataArray 'chain' ()>
array(2)
Coordinates:
    chain    int64 2 reached the maximum tree depth. Increase `max_treedepth`, increase `target_accept` or reparameterize.
Chain <xarray.DataArray 'chain' ()>
array(3)
Coordinates:
    chain    int64 3 reached the maximum tree depth. Increase `max_treedepth`, increase `target_accept` or reparameterize.

Yeah, same here.

I’ve just installed mamba, created a fresh environment and installed pymc, and I get the same message here as well:

$ which python
/Users/thomas/mambaforge/envs/pymc/bin/python

$ python test.py
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [a]
[8000/8000 00:00<00:00 Sampling 4 chains, 0 divergences]
Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 0 seconds.
Chain <xarray.DataArray 'chain' ()>
array(0)
Coordinates:
    chain    int64 0 reached the maximum tree depth. Increase `max_treedepth`, increase `target_accept` or reparameterize.
Chain <xarray.DataArray 'chain' ()>
array(1)
Coordinates:
    chain    int64 1 reached the maximum tree depth. Increase `max_treedepth`, increase `target_accept` or reparameterize.
Chain <xarray.DataArray 'chain' ()>
array(2)
Coordinates:
    chain    int64 2 reached the maximum tree depth. Increase `max_treedepth`, increase `target_accept` or reparameterize.
Chain <xarray.DataArray 'chain' ()>
array(3)
Coordinates:
    chain    int64 3 reached the maximum tree depth. Increase `max_treedepth`, increase `target_accept` or reparameterize.

More info has cropped up here.

Excellent! This is fixed in 5.1.2! Can confirm!
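To check whether an environment already has the fixed release, here is a sketch that uses only the standard library. The helper names are hypothetical. The version parsing is deliberately simple and ignores pre-release tags such as `rc1`. If the `packaging` package is installed, `packaging.version.Version` is the more robust choice.

```python
import re
from importlib.metadata import PackageNotFoundError, version

# Release reported in this thread as containing the fix.
FIXED_IN = (5, 1, 2)


def release_tuple(ver: str) -> tuple:
    """Leading numeric release part, e.g. '5.1.2.post1' -> (5, 1, 2).

    Simplification: pre-release tags such as 'rc1' are ignored, so
    '5.1.2rc1' is treated like '5.1.2'.
    """
    match = re.match(r"\d+(?:\.\d+)*", ver)
    if match is None:
        raise ValueError(f"unrecognised version string: {ver!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def has_treedepth_message_fix(ver: str) -> bool:
    # Tuple comparison is element-wise, so (5, 10, 0) > (5, 1, 2) as intended,
    # and a shorter tuple like (5, 1) sorts before (5, 1, 2).
    return release_tuple(ver) >= FIXED_IN


try:
    installed = version("pymc")
except PackageNotFoundError:
    installed = None

if installed is None:
    print("pymc is not installed in this environment")
else:
    print(f"pymc {installed}: fix included = {has_treedepth_message_fix(installed)}")
```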
