First, just a general note on broadcasting. Suppose you have a feature matrix `X` of shape `(n, k)`, and you draw a vector `beta` with shape `(k,)`, along with an intercept term `alpha`. In the base case, we can just do `alpha + X @ beta` to get the linear combination of the `X`'s.
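As a quick numpy sketch of the shapes involved (the sizes here are made up for illustration):

```python
import numpy as np

rng = np.random.default_rng(0)
n, k = 5, 3
X = rng.normal(size=(n, k))   # feature matrix, shape (n, k)
beta = rng.normal(size=k)     # one coefficient per feature, shape (k,)
alpha = 0.5                   # scalar intercept

mu = alpha + X @ beta         # broadcasting gives shape (n,)
print(mu.shape)
```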
Now suppose we have a single grouping index, `group_idx`, of shape `(n,)`, representing `g` groups. We adjust `beta` to reflect this by drawing one beta for each group-feature pair: `beta = pm.Normal('beta', 0, 1, dims=('groups', 'features'))`. Now `beta` has shape `(g, k)`. If I index it with `beta[group_idx]`, it becomes `(n, k)`, and we can make the linear combination:

`alpha[group_idx] + (X * beta[group_idx]).sum(axis=-1)`

The `-1` axis is the feature axis, so this computation multiplies each row of `X` by the betas corresponding to that row's group, then adds them all up. This is exactly the same as `X @ beta` back when we only had one group.
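You can check the multiply-and-sum step in plain numpy, comparing it against an explicit row-by-row loop (toy sizes, names mirroring the ones above):

```python
import numpy as np

rng = np.random.default_rng(0)
n, k, g = 6, 3, 2
X = rng.normal(size=(n, k))
group_idx = rng.integers(0, g, size=n)   # which group each row belongs to
alpha = rng.normal(size=g)               # one intercept per group
beta = rng.normal(size=(g, k))           # one beta vector per group

# beta[group_idx] has shape (n, k): each row picks up its group's betas
mu = alpha[group_idx] + (X * beta[group_idx]).sum(axis=-1)   # shape (n,)

# same thing computed one row at a time
mu_loop = np.array([alpha[gi] + X[i] @ beta[gi]
                    for i, gi in enumerate(group_idx)])
assert np.allclose(mu, mu_loop)
```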
Finally, suppose we have multiple groupings. Consider pooling only the intercept, so we're back to the `X @ beta` case, but we want the intercept of each row to be determined by two factors: the group and the color. To accomplish this, we simply make two intercepts: `intercept_group = pm.Normal('intercept_group', 0, 1, dims=['groups'])` and `intercept_color = pm.Normal('intercept_color', 0, 1, dims=['colors'])`. The intercept of each row will be the sum of the two, so we have:

`mu = intercept_group[group_idx] + intercept_color[color_idx] + X @ beta`
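Shape-wise this is just two fancy-indexed vectors plus the usual matrix product; a quick numpy check (arbitrary sizes):

```python
import numpy as np

rng = np.random.default_rng(1)
n, k, g, c = 8, 3, 2, 4
X = rng.normal(size=(n, k))
beta = rng.normal(size=k)
group_idx = rng.integers(0, g, size=n)
color_idx = rng.integers(0, c, size=n)
intercept_group = rng.normal(size=g)   # one intercept per group
intercept_color = rng.normal(size=c)   # one intercept per color

# each row's intercept is the sum of its group and color intercepts
mu = intercept_group[group_idx] + intercept_color[color_idx] + X @ beta
print(mu.shape)   # (n,)
```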
If you now wanted slopes that vary by each grouping as well, the logic is exactly the same. We will have `beta_group` and `beta_color`, of shapes `(g, k)` and `(c, k)` (where `c` is the number of colors), then we use the multiply-and-sum method to make the linear combination:

`mu = intercept_group[group_idx] + intercept_color[color_idx] + (X * (beta_group[group_idx] + beta_color[color_idx])).sum(axis=-1)`
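Again purely as a numpy shape check: the two fancy-indexed slope matrices are both `(n, k)`, so they add elementwise before the multiply-and-sum.

```python
import numpy as np

rng = np.random.default_rng(2)
n, k, g, c = 8, 3, 2, 4
X = rng.normal(size=(n, k))
group_idx = rng.integers(0, g, size=n)
color_idx = rng.integers(0, c, size=n)
intercept_group = rng.normal(size=g)
intercept_color = rng.normal(size=c)
beta_group = rng.normal(size=(g, k))   # one slope vector per group
beta_color = rng.normal(size=(c, k))   # one slope vector per color

row_betas = beta_group[group_idx] + beta_color[color_idx]   # (n, k)
mu = (intercept_group[group_idx] + intercept_color[color_idx]
      + (X * row_betas).sum(axis=-1))
print(mu.shape)   # (n,)
```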
This approach assumes an additive structure between the dimensions, which I gather you object to. Following your approach more closely implies a separate parameter for every group-color pair, which we can then fancy-index. For example, the intercept could be:

`intercept = pm.Normal('intercept', 0, 1, dims=['group', 'color'])`

so its shape is `(g, c)`. Fancy indexing with both index variables, `intercept[group_idx, color_idx]`, results in a vector of length `n`.
Note, however, that these are not the same models! My model groups the rows by color and estimates an intercept, then by group and estimates an intercept, then combines these to get a row intercept. This second model splits the rows into group-color groups, and estimates the intercept for each one separately.
For `beta` in this second case, you should see the pattern by now: fancy-index from `(g, c, k)` down to `(n, k)`, then multiply-and-sum:

```python
betas = pm.Normal('betas', 0, 1, dims=['group', 'color', 'feature'])
mu = intercept[group_idx, color_idx] + (X * betas[group_idx, color_idx]).sum(axis=-1)
```
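The indexing behaves the same way in plain numpy, which is an easy way to convince yourself of the shapes before building the PyMC model (sizes here are arbitrary):

```python
import numpy as np

rng = np.random.default_rng(3)
n, k, g, c = 10, 3, 2, 4
X = rng.normal(size=(n, k))
group_idx = rng.integers(0, g, size=n)
color_idx = rng.integers(0, c, size=n)
intercept = rng.normal(size=(g, c))    # one intercept per (group, color) pair
betas = rng.normal(size=(g, c, k))     # one beta vector per (group, color) pair

# fancy indexing with both index vectors: (g, c, k) -> (n, k)
assert betas[group_idx, color_idx].shape == (n, k)

mu = intercept[group_idx, color_idx] + (X * betas[group_idx, color_idx]).sum(axis=-1)
print(mu.shape)   # (n,)
```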