Vectorisation of BART for multidimensional hierarchical data

TL;DR: I’m trying to make a BART model hierarchical by vectorising it, but I can’t get the coordinates/dimensions to work.

I have data that describes how a set of features produces a variable response. The feature space is large (n=34) relative to the number of simulations (n=500), each with n=44 observations, which makes this an ill-posed problem. Nonetheless, I’m trying to build a BART model that can help me explore the effects of the features on the response.

The data frame looks as follows:
df.csv (2.2 MB)

realization y_riverflow x_month x_year x_alpha_io x_a_wl_io x_bio_hum_cn x_b_wl_io x_dcatch_dlai_io x_dqcrit_io x_retran_l_io x_retran_r_io x_r_grow_io x_rootd_ft_io x_sigl_io x_sorp x_tleaf_of_io x_tlow_io x_tupp_io x_l_vg_soil
0 1 0.147001 1 0 1.721491 0.504935 1.639253 1.111166 1.473436 1.925014 0.985717 1.307895 1.450263 2.372860 0.519003 0.376978 1.086885 1.061581 0.901236 0.246626
1 1 0.180318 2 0 1.721491 0.504935 1.639253 1.111166 1.473436 1.925014 0.985717 1.307895 1.450263 2.372860 0.519003 0.376978 1.086885 1.061581 0.901236 0.246626
2 1 0.267611 3 0 1.721491 0.504935 1.639253 1.111166 1.473436 1.925014 0.985717 1.307895 1.450263 2.372860 0.519003 0.376978 1.086885 1.061581 0.901236 0.246626
3 1 0.323996 4 0 1.721491 0.504935 1.639253 1.111166 1.473436 1.925014 0.985717 1.307895 1.450263 2.372860 0.519003 0.376978 1.086885 1.061581 0.901236 0.246626
4 1 0.361983 5 0 1.721491 0.504935 1.639253 1.111166 1.473436 1.925014 0.985717 1.307895 1.450263 2.372860 0.519003 0.376978 1.086885 1.061581 0.901236 0.246626
5275 10 0.166746 8 43 1.712617 0.990487 1.288422 1.771050 1.423148 1.198517 0.119383 4.858345 1.383394 0.920621 1.640639 1.951183 0.963208 1.030759 1.096198 0.218377
5276 10 0.147451 9 43 1.712617 0.990487 1.288422 1.771050 1.423148 1.198517 0.119383 4.858345 1.383394 0.920621 1.640639 1.951183 0.963208 1.030759 1.096198 0.218377
5277 10 0.132874 10 43 1.712617 0.990487 1.288422 1.771050 1.423148 1.198517 0.119383 4.858345 1.383394 0.920621 1.640639 1.951183 0.963208 1.030759 1.096198 0.218377
5278 10 0.117224 11 43 1.712617 0.990487 1.288422 1.771050 1.423148 1.198517 0.119383 4.858345 1.383394 0.920621 1.640639 1.951183 0.963208 1.030759 1.096198 0.218377
5279 10 0.195211 12 43 1.712617 0.990487 1.288422 1.771050 1.423148 1.198517 0.119383 4.858345 1.383394 0.920621 1.640639 1.951183 0.963208 1.030759 1.096198 0.218377

… in this case restricted to the first 10 realizations.

Here realization is the group-level term: I expect the observations, y_riverflow, to be more similar within each realization than between realizations. The observations vary in time (x_month and x_year), and each realization has 32 fixed parameters (the other x_ terms) that don’t vary in time but do vary between realizations.
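To make that structure concrete, here is a minimal sketch of the check I have in mind. It runs on a small toy frame rather than the real df: the values are made up, and only the column names realization, y_riverflow, x_month, x_year, x_alpha_io and x_sorp are borrowed from the data above.

```python
import numpy as np
import pandas as pd

# Toy stand-in for df (hypothetical values): 3 realizations, 2 years x 12 months each.
rng = np.random.default_rng(0)
rows = []
for r in range(1, 4):
    alpha, sorp = rng.uniform(0.1, 2.0, size=2)  # drawn once per realization
    for year in range(2):
        for month in range(1, 13):
            rows.append({
                "realization": r,
                "y_riverflow": rng.gamma(2.0, 0.1),
                "x_month": month,
                "x_year": year,
                "x_alpha_io": alpha,
                "x_sorp": sorp,
            })
toy = pd.DataFrame(rows)

# Split the features into time-varying and time-invariant columns.
time_cols = ["x_month", "x_year"]
param_cols = [c for c in toy.columns if c.startswith("x_") and c not in time_cols]

# Count the distinct values each parameter takes within a realization.
n_unique = toy.groupby("realization")[param_cols].nunique()
constant_within = bool((n_unique == 1).all().all())

# Count the observations in each realization.
obs_per_realization = toy.groupby("realization").size()
```

On the real data the same groupby calls should confirm that every non-time x_ column takes a single value per realization, which is why realization is the natural grouping level.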

I can formulate a basic (and poorly specified) BART model as follows:

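import pandas as pd
import pymc as pm
import pymc_bart as pmb

# Split the data frame into features (x_*) and response (y_*)
x_cols = [col for col in df if col.startswith('x')]
y_cols = [col for col in df if col.startswith('y')]
y = df[y_cols].to_numpy().squeeze()
x = df[x_cols].to_numpy()

with pm.Model() as model: 
    # Sum of m=200 trees, giving one value per row of x
    w = pmb.BART("w", X=x, Y=y, m=200)
    
    sigma = pm.HalfNormal("sigma", 1.0)
    # Gamma likelihood for the observed river flow
    X = pm.Gamma("X", mu=w, sigma=sigma, observed=y)

    trace = pm.sample()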

For the hierarchical case, I assume you can’t define hyperpriors on the BART model itself, so I’ve tried vectorising the model instead:

obs_idx = list(fdf.index)
rel_idx, rel = pd.factorize(fdf["realization"], sort=True)

coords={'obs_id': obs_idx,
       'realization': rel.to_numpy()}

with pm.Model(coords=coords) as model:
    realization_idx = pm.MutableData("realization_idx", rel_idx, dims="obs_id")
    
    w_rel = pmb.BART("w", X=x, Y=y, m=200, dims="realization")
    w = w_rel[realization_idx]
    
    sigma_rel = pm.HalfNormal("sigma", 1.0, dims="realization")
    sigma = sigma_rel[realization_idx]
    
    
    X = pm.Gamma("X", mu=w, sigma=sigma, observed=y, dims='obs_id')

    trace = pm.sample()

This also samples (slowly) but eventually fails with:

ValueError: conflicting sizes for dimension 'realization': length 5280 on the data but length 10 on coordinate 'realization'

So I’m clearly getting the dimensions of the hierarchy wrong.
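My current guess at the cause, assuming pmb.BART returns one value per row of X regardless of the dims I pass: the variable "w" then has 5280 entries while the 'realization' coordinate has only 10 labels. The NumPy-only sketch below mimics those shapes. The array w_rel is a stand-in for the BART output, not real model output, and the 528 rows per realization are simply 5280 / 10 from the error message.

```python
import numpy as np

n_realizations, n_per_realization = 10, 528  # 10 * 528 = 5280 rows, as in the error
n_obs = n_realizations * n_per_realization

# Stand-in for the BART output: one value per row of X (assumed behaviour).
w_rel = np.arange(n_obs, dtype=float)

# Integer group index per row, like the first output of pd.factorize.
rel_idx = np.repeat(np.arange(n_realizations), n_per_realization)

# The coordinate has 10 labels, but the variable labelled with it has 5280 entries.
coord_len = n_realizations
sizes_conflict = w_rel.shape[0] != coord_len

# Indexing with rel_idx only ever touches the first 10 entries of w_rel.
w = w_rel[rel_idx]
touched = np.unique(w)
```

If that is right, w_rel[realization_idx] is not picking one value per realization at all. It reuses the BART values for the first 10 rows of the data, which would also explain why the model samples but only fails once the trace is labelled with the coords.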

What is a better way to do this?
Thanks in advance for your advice!

CC @aloctavodia