How to call the predict function on BART?

I’m trying to figure out how to use the predict function with BART. I know this is very new, and the only example notebook I can find does not use it or make predictions on new data.

I started with the demo notebook and tried adding this cell after completing the trace:

with Model() as model:
    BART.predict(BART, X_new=x_test)

But get the error: “AttributeError: type object ‘BART’ has no attribute ‘all_trees’”
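The error message says the lookup happened on the type object, which suggests the problem is calling predict on the class itself: BART.predict(BART, ...) passes the class as self, so an attribute that presumably only exists on a fitted instance (here all_trees) cannot be found. A minimal pure-Python sketch of that failure mode, using a hypothetical ToyBART class rather than PyMC's actual BART implementation:

```python
# Hypothetical stand-in for a class whose predict() relies on state
# (all_trees) that only exists on an instance, not on the class.
class ToyBART:
    def __init__(self):
        # In this toy, fitted trees are stored per instance.
        self.all_trees = []

    def predict(self, X_new):
        # Needs self.all_trees, so self must be an instance, not the class.
        return [sum(tree(x) for tree in self.all_trees) for x in X_new]


def try_class_call():
    """Mimic BART.predict(BART, X_new=...): pass the class itself as self."""
    try:
        ToyBART.predict(ToyBART, X_new=[1.0])
    except AttributeError as err:
        return str(err)
    return None


message = try_class_call()
print(message)  # type object 'ToyBART' has no attribute 'all_trees'

# Calling predict on an instance works, because all_trees exists there.
fitted = ToyBART()
fitted.all_trees = [lambda x: 2.0 * x, lambda x: 1.0]
print(fitted.predict([0.0, 1.5]))  # [1.0, 4.0]
```

In the real model the instance would be the one attached to the named random variable, which is why the replies below go through μ.distribution rather than the BART class.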

Does anyone know how to correctly call this function for predicting on new X data?



Hi @justindlwhite, there is a bug in the code: to make the predict method work, you need to sample a single chain.


Okay, I tried rerunning the code with trace = pm.sample(cores = 2, chains=1… but I am still getting the same error. Would it be possible to add an example of the predict method to the example notebook?

This should work.

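import pymc3 as pm

with pm.Model() as model:
    # X, Y: training predictors and response
    μ = pm.BART('μ', X, Y)
    σ = pm.HalfNormal('σ', 1)

    y = pm.Normal('y', μ, σ, observed=Y)
    # a single chain works around the predict bug mentioned above
    trace_u = pm.sample(2000, chains=1)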

Improving the documentation, the API, and the range of models BART can run is on my to-do list.


Thanks so much, that worked!

One clarifying question about this: when I call μ.distribution.predict(X_new), does it return the Y values or the tree distribution μ?

It returns μ. That will change in the future to return Y.
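In BART, μ is a sum of regression trees, so "returning μ" means evaluating every tree at each new row and adding up the leaf values. A toy sketch of that idea, assuming hypothetical depth-1 trees (stumps) written as plain tuples and one list of trees per posterior draw; this is not PyMC's internal tree representation, and the exact shape PyMC's predict returns is not shown in this thread:

```python
import numpy as np

# Hypothetical fitted ensemble: two posterior draws, each a list of stumps
# given as (feature index, split value, left leaf value, right leaf value).
posterior_trees = [
    [(0, 0.5, -1.0, 1.0), (1, 2.0, 0.2, 0.4)],  # draw 0
    [(0, 0.3, -0.8, 1.2), (1, 1.5, 0.1, 0.5)],  # draw 1
]


def stump_predict(stump, X):
    """Leaf value of one stump for every row of X."""
    feature, split, left, right = stump
    return np.where(X[:, feature] <= split, left, right)


def predict_mu(posterior_trees, X_new):
    """Sum the trees of each draw; result has shape (draws, rows)."""
    return np.array([
        sum(stump_predict(stump, X_new) for stump in trees)
        for trees in posterior_trees
    ])


X_new = np.array([[0.1, 3.0], [0.9, 1.0]])
mu_draws = predict_mu(posterior_trees, X_new)
print(mu_draws)           # μ for each draw at each new row
print(mu_draws.mean(0))   # posterior mean of μ at each new row
```

Either way the result is on the scale of μ, not y, which is why the noise term σ still has to be added by hand.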

Got it. Given your model above, how could I get predicted y values for X_new?

For the moment you have to do it by hand. For example, you can approximate it with something like:

from scipy import stats
stats.norm(μ.distribution.predict(X_new), trace_u["σ"].mean())
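The approximation above plugs in the posterior mean of σ. A variant (not from this thread) is to draw y once per posterior sample of σ, so uncertainty in σ is carried into the predictions. A NumPy sketch, where mu_pred and sigma_draws are hypothetical stand-ins for μ.distribution.predict(X_new) and trace_u["σ"]:

```python
import numpy as np

rng = np.random.default_rng(seed=0)

# Hypothetical stand-ins: mu_pred plays the role of the μ predictions at
# X_new, sigma_draws the role of the posterior samples trace_u["σ"].
mu_pred = np.array([1.0, -0.5, 2.0])
sigma_draws = rng.uniform(0.9, 1.1, size=4000)

# Broadcast to shape (n_sigma_draws, n_new_rows): one simulated y per
# σ draw and per new row, y ~ Normal(μ, σ).
y_new = rng.normal(loc=mu_pred[None, :], scale=sigma_draws[:, None])

# Summaries of the approximate posterior predictive for each new row.
y_mean = y_new.mean(axis=0)
y_low, y_high = np.percentile(y_new, [5, 95], axis=0)
print(y_mean, y_low, y_high)
```

If the μ predictions also come as one row per posterior draw, the same idea applies by pairing each row of μ with the matching σ draw instead of broadcasting a single mu_pred.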

Hey, I know this thread is quite old, but I was wondering whether you know how to call predict on BART in newer versions of PyMC. I am getting the error below, and I am not sure how to sample from the decision trees with new X values now. Any suggestions would be really appreciated!

AttributeError                            Traceback (most recent call last)
Input In [148], in <cell line: 1>()
----> 1 bart.distribution

AttributeError: 'TensorVariable' object has no attribute 'distribution'