PyMC3 posterior prediction example; how to incorporate a matrix of betas?

Here is my (working) attempt at adapting the examples (again setting aside that a plain linear model is not really “right” for Likert data).

I would be super interested in your feedback on this general framework setup for acquiring the different objects!

import numpy as np
import pandas as pd
import pymc3 as pm
import arviz as az
import matplotlib.pyplot as plt

# generate Likert y; integers 1-7
obs_y = np.random.randint(7, size=(100,)) + 1
obs_y = pd.DataFrame(obs_y)

# generate a control group and two treatment arms (three levels total)
treatment = np.random.randint(3, size=(100,))
# convert to one-hot, dropping the control level as the baseline
treatment = pd.get_dummies(data=treatment, drop_first=True)
treatment.columns = ['treat_{}'.format(x) for x in range(2)]

# generate many categorical controls
covariates = np.random.randint(2, size=(100, 56))
covariates = pd.DataFrame(covariates)  # to df
covariates.columns = ['var_{}'.format(x) for x in range(56)]

# coords for pymc3
coords = {
          'obs_id': np.arange(len(obs_y)),
          'treatment': treatment.columns,
          'covariates': covariates.columns
    }


# specify model
with pm.Model(coords=coords) as main_lm_model:
    # data containers
    covariates_data = pm.Data('covariates_data', covariates, dims=('obs_id', 'covariates'))
    treatment_data = pm.Data('treatment_data', treatment, dims=('obs_id', 'treatment'))

    # priors
    alpha = pm.Normal('alpha', mu=0, sigma=1)
    covariates_betas = pm.Normal("covariates_betas", mu=0, sigma=1, dims='covariates')
    treatment_betas = pm.Normal("treatment_betas", mu=0, sigma=1, dims='treatment')

    # model error
    sigma = pm.HalfNormal("sigma", sigma=1)

    # matrix-dot products (use the Data containers rather than the raw DataFrames,
    # so the graph is tied to them and pm.set_data can swap in new data later)
    m1 = pm.math.matrix_dot(treatment_data, treatment_betas)
    m2 = pm.math.matrix_dot(covariates_data, covariates_betas)

    # expected value of y
    mu = alpha + m1 + m2

    # Likelihood: Normal
    y = pm.Normal("y",
                  mu=mu,
                  sigma=sigma,
                  observed=obs_y.values.reshape(-1,),
                  dims='obs_id')

    # set sampler
    step = pm.NUTS([alpha, treatment_betas, covariates_betas, sigma], target_accept=0.9)

# Inference button (TM)!
with main_lm_model:
    lm_trace = pm.sample(draws=1000,
                         step=step,  # note: with an explicit step, the `init` argument is ignored
                         cores=4,
                         tune=500,  # burn in
                         return_inferencedata=True)  # seemingly have to do this to get idata?

# prior analysis
with main_lm_model:
    prior_pc = pm.sample_prior_predictive()

# posterior analysis
with main_lm_model:
    # ppc
    ppc = pm.fast_sample_posterior_predictive(trace=lm_trace,
                                              random_seed=1,
                                              # these have to be listed explicitly here,
                                              # or they won't appear in `ppc` for the calcs below
                                              var_names=['y', 'alpha', 'treatment_betas', 'covariates_betas', 'sigma']
                                              )
    # turn to idata
    ppc = az.from_pymc3(posterior_predictive=ppc)

# fast_sample_posterior_predictive already flattens chains * draws, hence the singleton chain dim
ppc.posterior_predictive['covariates_betas'].shape  # (1, 4000, 56)
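# for comparison, the trace itself keeps chains and draws separate; xarray can flatten
# them explicitly (a sketch of the equivalent stacking)
lm_trace.posterior['covariates_betas'].shape  # (4, 1000, 56)
lm_trace.posterior['covariates_betas'].stack(sample=('chain', 'draw')).shape  # (56, 4000)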

# plot ppc
az.plot_ppc(data=ppc, num_pp_samples=100);

# try to calculate mu across the posterior like:
# https://docs.pymc.io/notebooks/posterior_predictive.html
treatment_mm = np.matmul(
                         # remove chain dim
                         np.squeeze(ppc.posterior_predictive['treatment_betas'].values),
                         lm_trace.constant_data['treatment_data'].T.values
                         )

covariate_mm = np.matmul(
                         # remove chain dim
                         np.squeeze(ppc.posterior_predictive['covariates_betas'].values),
                         lm_trace.constant_data['covariates_data'].T.values
                         )
# broadcast alpha
mu_pp = np.squeeze(ppc.posterior_predictive['alpha'].values)[:, None] + treatment_mm + covariate_mm
mu_pp.shape  # 4000 samples, 100 obs
mu_pp_mean = mu_pp.mean(0)  # get mean down dim 0
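
# alternative sketch: the betas already live in lm_trace.posterior, so the same mu can be
# built without listing them in var_names above (assumes the dims registered in coords
# carry through to the InferenceData, as they should with return_inferencedata=True)
post = lm_trace.posterior.stack(sample=('chain', 'draw'))
treatment_mm_alt = np.matmul(post['treatment_betas'].transpose('sample', 'treatment').values,
                             lm_trace.constant_data['treatment_data'].T.values)
# covariates and alpha follow the same pattern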

_, ax = plt.subplots()
ax.plot(treatment.iloc[:, 0], obs_y.values.reshape(-1), "o", ms=4, alpha=0.4, label="Data")
ax.plot(treatment.iloc[:, 0], mu_pp_mean, label="Mean outcome", alpha=0.6)
az.plot_hpd(
    treatment.iloc[:, 0],
    mu_pp,
    ax=ax,
    fill_kwargs={"alpha": 0.8, "label": "Mean outcome 94% HPD"},
)
az.plot_hpd(
    treatment.iloc[:, 0],
    ppc.posterior_predictive['y'],
    ax=ax,
    fill_kwargs={"alpha": 0.8, "color": "#a1dab4", "label": "Outcome 94% HPD"},
)
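
# because the design matrices sit in pm.Data containers (and the dot products use them),
# out-of-sample prediction should just be a data swap. A sketch, where `new_treatment`
# and `new_covariates` are hypothetical DataFrames with the same columns as the originals:
with main_lm_model:
    pm.set_data({'treatment_data': new_treatment, 'covariates_data': new_covariates})
    oos_ppc = pm.fast_sample_posterior_predictive(trace=lm_trace, var_names=['y'])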