PyMC3 posterior prediction example; how to incorporate a matrix of betas?

Thanks again for your patience and help!

I think I have come pretty far and I would be curious to see your thoughts; I think the last issue remaining is that I cannot quite figure out your last recommendation about the dimensions. Could you offer a quick code example? I tried playing with substituting treatment[:, None] in a few areas to no avail.

Setup

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pymc3 as pm
import arviz as az
import theano
import xarray as xr
from sklearn.datasets import make_regression

# Simulate a regression problem: 3 informative features, known true
# coefficients (returned because coef=True), fixed seed for reproducibility.
X, target, coef = make_regression(n_samples=100,
                                  bias=1.0,
                                  n_informative=3,
                                  n_features=3,
                                  noise=2.5,
                                  random_state=1,
                                  coef=True)

# Split the design matrix: column 0 is the "treatment" of interest,
# the remaining columns are adjustment covariates.
# (Removed the redundant no-op `target = target` self-assignment.)
treatment = X[:, 0]
covariates = X[:, 1:]

# coords
# Named dimensions handed to pm.Model so variables carry labelled dims:
#   treatment  -> the single treatment slope
#   covariates -> the two adjustment-covariate slopes
#   obs_id     -> one entry per observation/row
coords = dict(
    treatment=['treatment_1'],
    covariates=['covariate_1', 'covariate_2'],
    obs_id=np.arange(len(target)),
)


# specify model
with pm.Model(coords=coords) as sk_lm:
    # Shared data containers. Registering the arrays through pm.Data is what
    # lets you later call pm.set_data(...) to swap in new inputs for
    # out-of-sample posterior prediction, and it attaches the named dims.
    treatment_data = pm.Data('treatment_data', treatment, dims=('obs_id'))
    covariates_data = pm.Data('covariates_data', covariates, dims=('obs_id', 'covariates'))

    # priors
    alpha = pm.Normal('alpha', mu=0, sigma=1)
    # Scalar slope for the single treatment column; kept scalar so that
    # treatment_beta * treatment_data broadcasts cleanly over obs_id.
    treatment_beta = pm.Normal("treatment_beta", mu=44, sigma=2)
    covariates_betas = pm.Normal('covariates_betas', mu=[87, 58], sigma=2, dims=('covariates'))

    # model error
    sigma = pm.HalfNormal("sigma", sigma=1)

    # FIX: dot through the *Data container*, not the raw numpy array.
    # Using `covariates` directly bakes the training matrix into the graph,
    # so pm.set_data / posterior prediction on new data would silently be
    # ignored. (obs_id, covariates) @ (covariates,) -> (obs_id,)
    m1 = pm.math.matrix_dot(covariates_data, covariates_betas)

    # expected value of y — same fix: use treatment_data, not treatment
    mu = alpha + (treatment_beta * treatment_data) + m1

    # Likelihood: Normal
    y = pm.Normal("y",
                  mu=mu,
                  sigma=sigma,
                  observed=target,
                  dims='obs_id')

    # set sampler
    step = pm.NUTS([alpha, treatment_beta, covariates_betas, sigma], target_accept=0.9)

    # Inference button (TM)!
    lm_trace = pm.sample(draws=1000,
                           step=step,
                           init='jitter+adapt_diag',
                           cores=4,
                           tune=500,  # burn in
                           return_inferencedata=False)

    # prior analysis
    prior_pc = pm.sample_prior_predictive()

    # posterior predictive
    ppc = pm.fast_sample_posterior_predictive(trace=lm_trace,
                                              random_seed=1,
                                              )
    # inference data
    lm_idata = az.from_pymc3(
                             trace=lm_trace,
                             prior=prior_pc,
                             posterior_predictive=ppc,
                             )

Posterior Analysis

# Posterior Analysis
# Pull the posterior draws and the constant (input) data back out of the
# InferenceData object.
post = lm_idata.posterior
const = lm_idata.constant_data

# Expected outcome along the treatment axis: alpha + beta1 * treatment.
# xarray broadcasting gives one curve per (chain, draw) pair.
mu = post["alpha"] + post["treatment_beta"] * const["treatment_data"]

# Sort order of the observed treatment values (for monotone line plots).
argsort1 = np.argsort(treatment)

# 94% highest-density intervals: one for the posterior predictive draws of
# y, one for the expected outcome mu.
hdi_data_yhat = az.hdi(lm_idata.posterior_predictive["y"])
hdi_data_eY = az.hdi(mu)

# Mean Outcome Plot
# start plot
# start plot
_, ax = plt.subplots(figsize=(8, 8))
# plot raw data
ax.plot(treatment, target, "o", ms=4, alpha=0.4, label="Data")

# Mean outcome: E(y) as a function of treatment (alpha + b1*x1).
# FIX: sort by treatment (argsort1 was computed above but never used) —
# drawing a line over unsorted x produces a zigzag instead of a curve.
mean_outcome = mu.mean(dim=("chain", "draw")).values
ax.plot(treatment[argsort1], mean_outcome[argsort1],
        label="Mean outcome", alpha=0.6)

# 94% HDI band for the expected outcome, given the trained model.
# NOTE: arviz.plot_hdi superseded arviz.plot_hpd; hdi_prob is ignored when
# precomputed hdi_data is supplied, so it is not passed here.
az.plot_hdi(x=treatment,  # observed treatment
            hdi_data=hdi_data_eY,  # 94% HDI for E(y)
            ax=ax,
            fill_kwargs={"alpha": 0.8, "label": "Mean outcome 94% HPD"})

# 94% HDI band for posterior-predictive (possible unobserved) y values.
az.plot_hdi(x=treatment,  # observed treatment data
            hdi_data=hdi_data_yhat,  # set of possible unobserved y
            ax=ax,
            fill_kwargs={"alpha": 1.0, "color": "#a1dab4", "label": "Outcome 94% HPD"})

ax.set_xlabel("Predictor")
ax.set_ylabel("Outcome")
ax.set_title("Posterior predictive checks")
ax.legend(ncol=2, fontsize=10);