Hmm. This hack works, but it requires re-creating the index by hand.
# Varying-intercept regression over four runs: each run gets its own
# intercept a[k], and all runs share a single slope b.
with pm.Model() as m_fake_multiple:
    x_pred = pm.Data('x_pred', tmp_dfs[0]['x'])
    y_obs = pm.Data('y_obs', tmp_dfs[0]['y'])

    # The run index is rebuilt by hand instead of being read from the frame:
    #     run_id = pm.Data('run_id', tmp_dfs[0]['id'])
    # NOTE(review): this assumes N is a multiple of 4 and that the rows are
    # ordered run 0, 1, 2, 3 in equal-sized blocks -- confirm against the data.
    # Floor division keeps the repeat count an integer; `N / 4` is a float in
    # Python 3 and np.repeat does not accept float repeat counts.
    run_id = pm.Data('run_id', np.repeat([0, 1, 2, 3], repeats=int(N // 4)))

    a = pm.Normal('a', mu=0, sigma=10, shape=4)  # one intercept per run
    b = pm.Normal('b', mu=0, sigma=10, shape=1)  # slope shared by all runs

    # Linear predictor: look up each observation's intercept via its run index.
    lp = a[run_id] + b * x_pred
    _ = pm.Normal('y', mu=lp, sigma=1, observed=y_obs)