Evaluating Hierarchical MMM Models: How to Correctly Compute Posterior Predictive Metrics?

Hey!

I’d like to compute model evaluation summaries for the hierarchical model example in the Example Gallery, but it doesn’t seem as straightforward as for a single, national-level model. For non-hierarchical models we can pass predictions straight to compute_summary_metrics, but in the hierarchical case the posterior predictive samples first need to be reshaped across (date, geo).

The code below runs, but when I applied the same approach to my own custom model the results looked incorrect. Would love to hear your thoughts. Thank you!


from pymc_marketing.mmm.evaluation import compute_summary_metrics

hdi_prob = 0.89

# 1) Get posterior predictive samples
# Use the posterior_predictive group that's already in mmm.idata
y_pred = mmm.idata.posterior_predictive["y_original_scale"]




# 2) Stack chain and draw into a single sample dimension
y_pred_stacked = y_pred.stack(sample=("chain", "draw"))
print("After stacking chains:", y_pred_stacked.dims, "shape:", y_pred_stacked.shape)



# 3) Stack (date, geo) into a single observation dimension
y_pred_obs = y_pred_stacked.stack(obs=("date", "geo"))

# Reorder the dimensions to (obs, sample), which compute_summary_metrics expects
y_pred_final = y_pred_obs.transpose("obs", "sample")
print("Final y_pred dims:", y_pred_final.dims, "shape:", y_pred_final.shape)
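To convince myself the two stacks do what I think, I wrote a numpy-only sketch with made-up dimension sizes (no model needed). Assuming xarray's default C-order stacking (the last dim listed varies fastest), the pair of stack calls plus the transpose should be equivalent to a plain reshape-and-transpose:

```python
import numpy as np

# Toy stand-in for y_pred with dims (chain, draw, date, geo); sizes are made up.
chains, draws, dates, geos = 4, 250, 52, 3
y_pred_demo = np.zeros((chains, draws, dates, geos))

# stack(sample=("chain", "draw")), then stack(obs=("date", "geo")), then
# transpose("obs", "sample") should match this for a C-contiguous array:
flat = y_pred_demo.reshape(chains * draws, dates * geos).T
print(flat.shape)  # (dates * geos, chains * draws) = (156, 1000)
```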



# 4) Get the observed y values in the same order
# y_train must be flattened in the same (date, geo) order as the stacked obs dimension
y_true = y_train.values
print("y_true shape:", y_true.shape)
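On step 4: if y_train is an xarray object with (date, geo) dims, I'd stack it the same way (y_train.stack(obs=("date", "geo")).values) rather than relying on sort order. In case it's a plain 2-D array instead, this toy example shows the order that has to match (again assuming xarray's default C-order stacking, i.e. geo varies fastest):

```python
import numpy as np

# Hypothetical 3-date x 2-geo observed matrix, y_obs[d, g]:
y_obs = np.array([[10, 20],
                  [11, 21],
                  [12, 22]])

# C-order flatten iterates geo fastest, matching stack(obs=("date", "geo")):
y_true_demo = y_obs.reshape(-1)
print(y_true_demo)  # [10 20 11 21 12 22] -- date-major, geo fastest
```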



# 5) Verify alignment

assert len(y_true) == y_pred_final.sizes["obs"], (

f"Mismatch: y_true={len(y_true)}, y_pred obs dim={y_pred_final.sizes['obs']}"

)




# 6) Compute metrics
metrics_to_calculate = [
    "r_squared",
    "rmse",
    "nrmse",
    "mae",
    "nmae",
    "mape",
]




results = compute_summary_metrics(
    y_true=y_true,
    y_pred=y_pred_final,
    metrics_to_calculate=metrics_to_calculate,
    hdi_prob=hdi_prob,
)

# 7) Print results
print("\n" + "=" * 50)
print("MODEL EVALUATION METRICS")
print("=" * 50)
for metric, stats in results.items():
    print(f"\n{metric.upper()}:")
    for stat, value in stats.items():
        print(f"  {stat}: {value:.4f}")
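To double-check whatever compute_summary_metrics reports, I've also been recomputing one metric by hand on synthetic data shaped like (obs, sample). A numpy-only RMSE-per-sample sketch (all names and sizes here are made up, not model output):

```python
import numpy as np

# Synthetic stand-ins for y_true (obs,) and y_pred_final (obs, sample).
rng = np.random.default_rng(42)
n_obs, n_sample = 156, 1000
y_true_demo = rng.normal(loc=100.0, scale=10.0, size=n_obs)
y_pred_demo = y_true_demo[:, None] + rng.normal(scale=1.0, size=(n_obs, n_sample))

# RMSE of each posterior sample, then a summary over samples; with unit
# noise the mean RMSE should land near 1.0.
rmse_per_sample = np.sqrt(((y_pred_demo - y_true_demo[:, None]) ** 2).mean(axis=0))
print(rmse_per_sample.shape)  # (1000,)
print(rmse_per_sample.mean())
```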