Dynamic Slicing of Array for multi-index hierarchical model

I am trying to dynamically slice my hierarchical model so that it can make predictions on new data of a different shape. The issue I am running into is that I can’t find a way to slice my 2D tensor dynamically inside the model to extract the necessary elements after the dot product. The code below triggers the same error I am running into, namely that there are no arrays when doing advanced indexing.

After I compute my dot product, I only want to retain (i.e. slice out) one element per row: the one in the column given by that row's group index.

I have tried setting the shape/pt.arange() component both as a pm.Deterministic and as a pm.MutableData structure, to no avail.

import numpy as np
import pymc as pm
import pytensor.tensor as pt
# Simulate a small data set in which every observation belongs to one group
# and its response is built from that group's column of weights only.
np.random.seed(123)
n, variables, groups = 5, 3, 2
weights = np.random.randn(variables, groups)
x = np.random.randn(n, variables)
group_index = np.random.choice(groups, n)
variable_names = [f"var{i}" for i in range(variables)]
# x.dot(weights) has one column per group; the integer-array index keeps,
# for row i, only the entry in column group_index[i].
row_ids = np.arange(n)
xw = x.dot(weights)[row_ids, group_index]
# Add a small amount of Gaussian noise to obtain the observed response.
noise = np.random.randn(n) / 50
y = xw + noise
# Hierarchical regression: one weight per (variable, group) pair, with the
# group-level mean and scale of the weights estimated jointly.
with pm.Model() as model:
    # 'obs' is registered as mutable so its length can be changed later via
    # pm.set_data; the other two coordinates are fixed.
    model.add_coord('obs', np.arange(n), mutable=True)
    model.add_coord('var_names', variable_names, mutable=False)
    model.add_coord('group_values', values=[0,1], mutable=False)
    data = pm.MutableData("data", x, dims=('obs','var_names'))
    sigma_variables = pm.HalfNormal("sigma_variables", sigma=1, dims = 'group_values')
    mu_variables = pm.Normal('mu_variables',mu=0, sigma=1, dims = 'group_values')
    var_weights = pm.Normal("var_weights", mu = mu_variables, sigma = sigma_variables, dims = ('var_names', 'group_values'))
    group_idx = pm.MutableData("group_idx", group_index, dims='obs')
    print(data.shape.eval()) #[5 3]
    print(var_weights.shape.eval()) #[3 2]
    mu = pt.dot(data,var_weights)
    print(mu.shape.eval()) #[5 2]
    # NOTE: this is the line the question is about and it is intentionally
    # left as posted. `.eval()` evaluates the shape while the model is being
    # built, so the row index is a fixed-length arange rather than a symbolic
    # function of `data`, and it does not follow a later pm.set_data call.
    mu = mu[pt.arange(mu.shape.eval()[0]),group_idx] # needs to be dynamic
    print(mu.shape.eval()) #[5]
    mu_d = pm.Deterministic('mu_d',mu)
    sigma = pm.HalfNormal('sigma', sigma=1)
    # The variable name must be a string; the bare identifier LKHD_LBL was
    # never defined and raised NameError before the model could be built.
    pm.Normal('LKHD_LBL', mu=mu_d, sigma=sigma, observed=y)
    
# Draw a short posterior sample from the model.
with model:
    trace = pm.sample(
        draws=100,
        tune=100,
        chains=2,
        target_accept=.99,
        random_seed=123,
    )
# New inputs with a different number of rows (k instead of n).
k = 2
new_data = np.random.randn(k, variables)
group_index = np.random.choice(groups, k)
with model:
    # Swap in the new inputs and resize the 'obs' coordinate to match.
    # The coordinate key must be the string 'obs' used in model.add_coord;
    # the bare identifier OBS_LBL was never defined and raised NameError.
    pm.set_data(
        coords={'obs': np.arange(k)},
        new_data={
            "data": new_data,
            "group_idx":group_index
        },
    )
    # Recompute only the deterministic mean for the new rows.
    ppc_samples = pm.sample_posterior_predictive(
            trace,
            var_names=["mu_d"],
            random_seed=123,
        )

Well, I have an answer, but I don’t particularly like it. Basically, pass in a new 0/1 indicator matrix in which the 1 in each row marks that observation's group label, take the dot product with it, and then take the diagonal of the resulting matrix. It feels kind of odd, as I want to jointly estimate a bunch of things of the same shape, and it might be easier to just have three separate models.

# Simulate grouped regression data (three groups this time).
np.random.seed(123)
n, variables, groups = 10, 4, 3
weights = np.random.randn(variables, groups)
x = np.random.randn(n, variables)
group_index = np.random.choice(groups, n)
variable_names = [f"var{i}" for i in range(variables)]
# Full (n, groups) product, then the per-row entry for each row's own group.
xw = x.dot(weights)
p = xw[np.arange(n), group_index]
y = p + np.random.randn(n) / 50

# One 0/1 indicator vector per group label, stacked as the columns of an
# (n, 3) matrix with exactly one 1 per row.
zero_idx = (group_index == 0).astype(int)
one_idx = (group_index == 1).astype(int)
two_idx = (group_index == 2).astype(int)
elewise_array = np.column_stack((zero_idx, one_idx, two_idx))
# Multiplying by the transposed indicator matrix gives an (n, n) matrix whose
# diagonal holds the same values as the fancy-indexed `p`; print both to compare.
p_dot = xw.dot(elewise_array.T)
print(np.diag(p_dot))
print(p)

# Same hierarchical regression, with a per-group intercept, where the
# per-row group selection is done with an indicator matrix and a diagonal
# instead of integer-array indexing.
with pm.Model() as model:
    model.add_coord('obs', np.arange(n), mutable=True)
    model.add_coord('var_names', variable_names, mutable=False)
    model.add_coord('group_values', values=['a','b','c'], mutable=False)
    
    data = pm.MutableData("data", x, dims=('obs','var_names'))
    # 0/1 indicator matrix with one row per observation and one column per group.
    idx_matrix = pm.MutableData("idx_matrix", elewise_array)
    group_idx = pm.MutableData("group_idx", group_index, dims='obs')
    
    sigma_variables = pm.HalfNormal("sigma_variables", sigma=1, dims = 'group_values')
    mu_variables = pm.Normal('mu_variables',mu=0, sigma=1, dims = 'group_values')
    var_weights = pm.Normal("var_weights", mu = mu_variables, sigma = sigma_variables, dims = ('var_names', 'group_values'))
    intercept = pm.Normal("intercept", mu=0, sigma=1, dims='group_values')
    
    xw = pt.dot(data,var_weights) # dot to get n by len(lvl) matrix
    # NOTE: the intermediate below is n by n, so memory and time grow
    # quadratically with the number of observations. The same values can be
    # obtained without it as the row-wise sum (xw * idx_matrix).sum(axis=1).
    xw_dot_idx = pt.dot(xw, idx_matrix.T) # dot with idx to get n by n
    xw_slice = pt.diag(xw_dot_idx) # strip diagonal elements

    mu = intercept[group_idx] + xw_slice
    mu_d = pm.Deterministic('mu_d',mu)
    sigma = pm.HalfNormal('sigma', sigma=1)
    # The variable name must be a string; the bare identifier LKHD_LBL was
    # never defined and raised NameError before the model could be built.
    pm.Normal('LKHD_LBL', mu=mu_d, sigma=sigma, observed=y)

I think you were already quite close to the solution in your first post. You are free to pass symbolic shapes to functions like pt.arange to get the fancy indexing (I think) you want. Here is what I tried:

# Coordinates are handed to the model constructor instead of being added
# one at a time.
coords = dict(
    obs=np.arange(n),
    var_names=variable_names,
    group_values=[0, 1],
)

with pm.Model(coords=coords) as model:
    # Data containers that are replaced later through pm.set_data.
    data = pm.Data("data", x, dims=("obs", "var_names"))
    group_idx = pm.Data("group_idx", group_index, dims="obs")

    # Group-level scale and mean of the weights, then one weight per
    # (variable, group) pair.
    sigma_variables = pm.HalfNormal(
        "sigma_variables", sigma=1, dims="group_values"
    )
    mu_variables = pm.Normal(
        "mu_variables", mu=0, sigma=1, dims="group_values"
    )
    var_weights = pm.Normal(
        "var_weights",
        mu=mu_variables,
        sigma=sigma_variables,
        dims=("var_names", "group_values"),
    )

    # The row index is built from the symbolic shape of `data` (no .eval()),
    # so it is recomputed from whatever `data` currently holds.
    linear_pred = data @ var_weights
    row_idx = pt.arange(data.shape[0])
    mu = linear_pred[row_idx, group_idx]
    mu_d = pm.Deterministic("mu_d", mu)

    sigma = pm.HalfNormal("sigma", sigma=1)
    pm.Normal("LKHD_LBL", mu=mu_d, sigma=sigma, observed=y)

# Fit with the NumPyro NUTS sampler.
with model:
    idata = pm.sample(nuts_sampler="numpyro")

# New inputs with k rows instead of n.
k = 2
new_data = np.random.randn(k, variables)
group_index = np.random.choice(groups, k)

with model:
    # Replace both data containers and resize the 'obs' coordinate to k rows.
    replacements = {"data": new_data, "group_idx": group_index}
    pm.set_data(replacements, coords={"obs": np.arange(k)})
    # Only the deterministic mean is requested; the result is returned as a
    # separate predictions object rather than merged into `idata`.
    idata_pred = pm.sample_posterior_predictive(
        idata,
        var_names=["mu_d"],
        random_seed=123,
        extend_inferencedata=False,
        predictions=True,
    )

Let me know if I misunderstood something

This works :joy: . Don’t know why I didn’t try data.shape[0]. My method still works, but will blow up if n is too big.

1 Like

As a rule of thumb just never call .eval() in your model and things should work out how you hope :slight_smile: