Multi-class BART Model Assistance

The following example seems to work.

# Toy data: 120 evenly spaced inputs on [0, 1] as a single-column matrix,
# and three class labels (0, 1, 2) with 40 consecutive observations each.
X = np.linspace(0, 1, 120).reshape(-1, 1)
Y = np.repeat(np.arange(3), 40)
plt.plot(X[:, 0], Y, ".")

[Figure "cat": plot of the class labels Y against X]

# Multi-class classification model: a BART latent score per class and
# observation, turned into class probabilities with a softmax.
with pm.Model() as bart_model:
    # BART prior with m=50 trees; shape=(3, 120) gives one row of latent
    # values per class (3 classes) and one column per observation (120 rows in X).
    μ = pmx.BART("μ", X=X, Y=Y, m=50, shape=(3, 120))
    # Softmax along axis 0 (the class axis), so the 3 values belonging to
    # each observation are normalised into probabilities.
    θ = pm.Deterministic('θ', pm.math.softmax(μ, axis=0))
    # Transpose to (120, 3) so each row is the probability vector passed as
    # `p` for the corresponding observed label in Y.
    y = pm.Categorical("y", p=θ.T, observed=Y)
    # Sample with PyMC's default settings; the result is stored in `idata`.
    idata = pm.sample()

Now we check the results by drawing synthetic labels from the posterior mean of θ and plotting them against X:

# Posterior mean of θ, averaged over all posterior draws ("sample" is the
# stacked chain/draw dimension). `az.extract` replaces the deprecated
# `az.extract_dataset`.
posterior_mean = az.extract(idata, var_names="θ").mean("sample")

# Put the observation axis ("θ_dim_1") first so that each row is the vector
# of class probabilities for one observation; the number of observations is
# taken from the array instead of being hard-coded.
class_probs = posterior_mean.transpose("θ_dim_1", ...).values

# Using the posterior mean, generate posterior-predictive samples (synthetic
# data): one multinomial draw with n=1 per observation, whose argmax is the
# sampled class label.
new_y = [np.random.multinomial(n=1, pvals=probs).argmax() for probs in class_probs]

plt.plot(X[:, 0], new_y, ".")

[Figure "cat_post": plot of the synthetic labels new_y against X]

1 Like