Binary classification example using the recently added BART model

I’m trying to figure out how to use the recently added BART model for a binary classification task on the Breast Cancer Wisconsin dataset.

I use a train/test split to evaluate the fitted model’s predictive performance on unseen data, so I defined X as a shared variable:

>>> X_shared = theano.shared(X_train)

The shapes of the datasets are as follows:

>>> X_train.shape, Y_train.shape, X_test.shape, Y_test.shape
((426, 30), (426,), (143, 30), (143,))
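
For reference, shapes like these are what scikit-learn's `train_test_split` produces on this dataset with its default 25% test size. The following is only a minimal sketch under assumptions: that the data comes from `load_breast_cancer` and that the split uses default settings. The `random_state` value is arbitrary, and the actual split lives in the Gist.

```python
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split

# Load the 569 x 30 feature matrix and the binary target vector.
X, Y = load_breast_cancer(return_X_y=True)

# Default test_size is 0.25; random_state=0 is an arbitrary choice
# made here only to keep the sketch reproducible.
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, random_state=0)

print(X_train.shape, Y_train.shape, X_test.shape, Y_test.shape)
```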

Inspired by the regression example provided by @aloctavodia, I tried a fairly simple model:

with pm.Model() as model:
    
    x = pm.BART('x', X_shared.get_value(), Y_train)
    y = pm.Bernoulli('y', p=pm.math.sigmoid(x), observed=Y_train)

    trace = pm.sample()

After plugging in the test data, I noticed that the shape of the posterior predictive samples wasn’t updated (it still has as many columns as there are training observations):

>>> X_shared.set_value(X_test)

>>> with model:
...     ppc = pm.sample_posterior_predictive(trace)
>>> posterior = ppc.get('y')
>>> posterior.shape
(2000, 426)

The questions I’m struggling with are:

  • Do I use the BART model the correct way for classification? How could I improve my model?
    Compared with a RandomForestClassifier with default hyperparameters (accuracy 96%), the results of this model seem to be no better than random guessing (accuracy 57%).
  • Why wasn’t the shape of the posterior updated to (2000, 143)?
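
For the accuracy comparison, one common way to reduce posterior predictive draws to a single accuracy figure is to average the draws per observation and threshold at 0.5. The following is a minimal sketch on made-up data. `accuracy_from_ppc` is a hypothetical helper, not part of PyMC3, and it assumes the draws are shaped (n_draws, n_observations). It also raises an error if the draws and the labels disagree in length, which is the situation described in the second question.

```python
import numpy as np

def accuracy_from_ppc(ppc_y, y_true, threshold=0.5):
    """Hypothetical helper: accuracy of thresholded posterior predictive means."""
    ppc_y = np.asarray(ppc_y)
    y_true = np.asarray(y_true)
    # Guard against comparing training-sized draws with test-sized labels.
    if ppc_y.ndim != 2 or ppc_y.shape[1] != y_true.shape[0]:
        raise ValueError(
            f"draws have shape {ppc_y.shape}, labels have shape {y_true.shape}"
        )
    # Mean over draws estimates P(y = 1) for each observation.
    p_hat = ppc_y.mean(axis=0)
    y_pred = (p_hat >= threshold).astype(int)
    return float((y_pred == y_true).mean())

# Made-up stand-in for ppc['y']: 3 draws for 4 observations.
fake_ppc = np.array([[0, 1, 1, 0],
                     [0, 1, 0, 0],
                     [1, 1, 1, 0]])
y_true = np.array([0, 1, 1, 0])

acc = accuracy_from_ppc(fake_ppc, y_true)
print(acc)
```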

To reproduce my case, here’s the complete Gist.

Many thanks in advance!

I will check your gist in the next few days. In the meantime, it is not yet possible to get out-of-sample predictions from BART, but it is something that will be added soon.

but it is something that will be added soon.

@aloctavodia, is there an issue I could follow along? Happy to create one if not.

There is an open PR adding out-of-sample predictions from BART, although not posterior predictive sampling: https://github.com/pymc-devs/pymc3/pull/4310. I will check on that stalled PR again.
