I think what you are looking for is a nested error term that is the same for every observation in group j. I would do something like:
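```python
from pymc3 import Model, Normal

with Model() as net_model:
    ...
    # One error term per group (10 groups), shared by every observation in that group
    template_eps = Normal('template_eps', 0, sd=1, shape=10)
    # Group-level predictor; var_one and var_two have one value per group
    template = b[0]*var_one + b[1]*var_two + template_eps
    # Expand to observations using the group index array `which`
    est = template[which] + b[2]*var_three
    ...
```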
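To see what `template[which]` does, here is a minimal NumPy sketch of the nested generative process. NumPy stands in for PyMC3's tensors, and the data (`n_obs`, the coefficient values in `b`, and the random predictors) are hypothetical; only the variable names come from the model above. It checks that the part of `est` not explained by `var_three` is identical for every observation in a given group.

```python
import numpy as np

rng = np.random.default_rng(0)

n_groups = 10  # matches shape=10 in the model above
n_obs = 50     # hypothetical number of observations

# Hypothetical data: var_one and var_two are group-level (one value per group),
# var_three is observation-level, and which maps each observation to its group.
var_one = rng.normal(size=n_groups)
var_two = rng.normal(size=n_groups)
var_three = rng.normal(size=n_obs)
which = rng.integers(0, n_groups, size=n_obs)
b = np.array([0.5, -1.0, 2.0])  # hypothetical coefficients

# One error draw per group, like template_eps with shape=10
template_eps = rng.normal(0.0, 1.0, size=n_groups)

# Build the group-level template, then expand it to observations by indexing
template = b[0] * var_one + b[1] * var_two + template_eps
est = template[which] + b[2] * var_three

# Remove the observation-level term; what remains depends only on the group
group_part = est - b[2] * var_three
within_group_constant = all(
    np.allclose(group_part[which == j], template[j]) for j in np.unique(which)
)
print(within_group_constant)
```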
which is equivalent to this flat version:
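```python
with Model() as flat_model:
    ...
    # Same per-group error term as in net_model
    template_eps = Normal('template_eps', 0, sd=1, shape=10)
    # Index each group-level term by `which` directly instead of building a template
    est = b[0]*var_one[which] + b[1]*var_two[which] + b[2]*var_three + template_eps[which]
    ...
```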
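The two models give the same `est` because indexing by `which` distributes over the elementwise sums. A quick NumPy check of that, assuming hypothetical random data (the helper names `nested_est` and `flat_est` are made up for this sketch; NumPy stands in for PyMC3's tensors):

```python
import numpy as np


def nested_est(b, var_one, var_two, var_three, eps, which):
    # Build the length-n_groups template first, then expand it to observations
    template = b[0] * var_one + b[1] * var_two + eps
    return template[which] + b[2] * var_three


def flat_est(b, var_one, var_two, var_three, eps, which):
    # Expand each group-level term to observations first, then add them up
    return (b[0] * var_one[which] + b[1] * var_two[which]
            + b[2] * var_three + eps[which])


# Hypothetical data, not from the question
rng = np.random.default_rng(1)
n_groups, n_obs = 10, 40
b = np.array([0.5, -1.0, 2.0])
var_one = rng.normal(size=n_groups)            # group-level predictor
var_two = rng.normal(size=n_groups)            # group-level predictor
var_three = rng.normal(size=n_obs)             # observation-level predictor
template_eps = rng.normal(size=n_groups)       # one error draw per group
which = rng.integers(0, n_groups, size=n_obs)  # group index per observation

nested = nested_est(b, var_one, var_two, var_three, template_eps, which)
flat = flat_est(b, var_one, var_two, var_three, template_eps, which)
print(np.allclose(nested, flat))
```

The nested form does its group-level arithmetic on 10 elements before expanding, while the flat form indexes each term separately; since the resulting predictor is the same, choosing between them is mostly a matter of readability.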