Multilevel Malaise

Apologies in advance for the lengthy post, but I struggled to condense things any further.

I have several questions related to working with hierarchical multilevel models. I’ve included a full discussion with an example below, but to summarize, my high-level questions are:

  1. Sampling posterior predictive simulations seems to be much slower than calculating the predictions directly from the posterior with xarray. In this example the speed difference is ~1.7 s vs. 20 ms. Is this expected, or am I doing something wrong?

  2. What approaches are there for fast model iteration? In my actual modelling workflow, fitting a model takes on the order of 30 minutes, which really slows down iteration. I’m wondering whether using ADVI for initial prototyping is recommended (see the sketch after the model fit below) or whether there are other approaches to consider.

  3. How should I think about building priors across multiple nested hierarchical levels?

    a. I have parameters that, based on my prior knowledge, I want to restrict to being positive. However, as I nest more levels, restricting each hierarchical level to be positive shifts the final prior to the right in ways that are unintuitive.

    b. How should I think about the variance parameters when nesting multiple hierarchical levels? I have some intuition for how the means of one level feed into the next, but how the variances of the various levels affect the model is much less clear to me.

  4. Any general tips on model design and on maintaining the index mappings between levels in a multilevel model? I have taken the approach of using a helper function (map_from_levels, below), but I’m wondering whether there are easier ways to do this kind of thing.

Below is an example with synthetic hierarchical data and the model I fit to it. I am using pymc version 4.2.1, and I reference the summary questions above inline.

The structure is that there are various producers which can be categorized into different groups and which operate in different countries.

countries = ["Canada", "France"]
groups = ["A", "B", "C"]
producers = ["P1", "P2", "P3", "P4", "P5", "P6"]

The product they produce is affected by a country-by-group-by-producer-level beta, and this beta has the following hierarchy: country → group → producer. You can think of this parameter as an elasticity to an external covariate X.

value = country_group_producer_param * X + rng.normal(0, 1, n)

The data is ragged in the sense that not all producers operate in all countries. Below is the full code for generating the data:

import numpy as np
import pandas as pd
import pymc as pm
import arviz as az


def generate_params(seed):
    rng = np.random.default_rng(seed)
    base_affect = 5
    country_affect = {
        "Canada": rng.normal(base_affect, 2, 1),
        "France": rng.normal(base_affect, 2, 1),
    }
    country_group_affect = {
        ("Canada", "A"): rng.normal(country_affect["Canada"], 1, 1),
        ("Canada", "C"): rng.normal(country_affect["Canada"], 1, 1),
        ("France", "A"): rng.normal(country_affect["France"], 1, 1),
        ("France", "B"): rng.normal(country_affect["France"], 1, 1),
    }
    
    country_group_producer_affect = {
        ("Canada", "A", "P1"): rng.normal(country_group_affect[("Canada", "A")], 1, 1),
        ("Canada", "A", "P2"): rng.normal(country_group_affect[("Canada", "A")], 1, 1),
        ("Canada", "A", "P3"): rng.normal(country_group_affect[("Canada", "A")], 1, 1),
        ("Canada", "C", "P6"): rng.normal(country_group_affect[("Canada", "C")], 1, 1),
        ("France", "A", "P1"): rng.normal(country_group_affect[("France", "A")], 1, 1),
        ("France", "B", "P1"): rng.normal(country_group_affect[("France", "B")], 1, 1),
        ("France", "B", "P1"): rng.normal(country_group_affect[("France", "B")], 1, 1),
    }
    return country_group_producer_affect, country_group_affect, country_affect
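
# A note on Q3b, kept as a comment next to the generating code: with
# country ~ Normal(5, 2), group ~ Normal(country, 1), and producer ~ Normal(group, 1),
# the implied marginal prior at the producer level is
# Normal(5, sqrt(2**2 + 1**2 + 1**2)) = Normal(5, ~2.45), i.e. the means pass through
# the levels unchanged while the variances add.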


def map_from_levels(from_lvl, to_lvl, data, from_lvl_index=None, to_lvl_index=None, array=True):
    """Helper function to map between levels"""
    from_lvl_index = from_lvl_index or f"{from_lvl}_index"
    to_lvl_index = to_lvl_index or f"{to_lvl}_index"
    mapping = data.sort_values(to_lvl_index).loc[:, [from_lvl_index, from_lvl, to_lvl, to_lvl_index]].drop_duplicates()
    if array:
        return mapping[from_lvl_index].values
    return mapping
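
# A hedged usage example (the index values assume the data built by generate() below):
#   map_from_levels("country", "country_group", data)
#   -> array([0, 0, 1, 1])
# i.e. Canada_A and Canada_C map back to country index 0 (Canada), while
# France_A and France_B map back to country index 1 (France).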



def generate(n=100, seed=42):
    rng = np.random.default_rng(seed)
    groups = ["A", "B", "C"]
    producers = ["P1", "P2", "P3", "P4", "P5", "P6"]
    countries = ["Canada", "France"]
    producer_group = pd.DataFrame([
        ("P1", "A"),
        ("P2", "A"),
        ("P3", "A"),
        ("P4", "B"),
        ("P5", "B"),
        ("P6", "C"),
    ], columns=["producer", "group"])
    producer_country = pd.DataFrame([
        ("P1", "Canada"),
        ("P1", "France"),
        ("P2", "Canada"),
        ("P3", "Canada"),
        ("P4", "France"),
        ("P5", "France"),
        ("P6", "Canada"),
    ], columns=["producer", "country"])
    
    country_group_producer_affect, _, _ = generate_params(seed)

    X = rng.normal(0, 1, n)
    dates = [f"t{t}" for t in range(n)]


    producer_group_lookup = producer_group.set_index("producer")["group"].to_dict()

    data = []
    for levels, param in country_group_producer_affect.items():
        value = param * X + rng.normal(0, 1, n)
        c, g, p = levels
        data.append(pd.DataFrame({"country": c, "group": g, "producer": p, "X": X, "value": value, "date": dates}))

    data = pd.concat(data, axis=0, ignore_index=True).rename_axis(index="obs_id").reset_index()
    data = data.assign(
        country_group=data["country"] + "_" + data["group"],
        country_group_producer=data["country"] + "_" + data["group"] + "_" + data["producer"],
    )
    country_idx, _ = data.country.factorize(sort=True)
    country_group_idx, _ = data.country_group.factorize(sort=True)
    country_group_producer_idx, _ = data.country_group_producer.factorize(sort=True)
    data = data.assign(
        country_index=country_idx,
        country_group_index=country_group_idx,
        country_group_producer_index=country_group_producer_idx,
    )
    return data

data = generate()

The data looks like the following

display(data.head())
display(data.loc[:, ["country", "group", "producer"]].drop_duplicates())
display(data.loc[:, ["country_group", "country_group_index"]].drop_duplicates())
display(data.loc[:, ["country_group_producer", "country_group_producer_index"]].drop_duplicates())

And the true parameters look like the following

true_country_group_producer_mean, true_country_group_mean, true_country_mean = generate_params(42)
display(true_country_mean)
display(true_country_group_mean)
display(true_country_group_producer_mean)
{'Canada': array([5.60943416]), 'France': array([2.92003179])}
{('Canada', 'A'): array([6.35988536]),
 ('Canada', 'C'): array([6.54999888]),
 ('France', 'A'): array([0.9689966]),
 ('France', 'B'): array([1.61785228])}
{('Canada', 'A', 'P1'): array([6.48772576]),
 ('Canada', 'A', 'P2'): array([6.04364276]),
 ('Canada', 'A', 'P3'): array([6.3430842]),
 ('Canada', 'C', 'P6'): array([5.69695495]),
 ('France', 'A', 'P1'): array([1.84839457]),
 ('France', 'B', 'P1'): array([1.68388298])}

I fit the following model

_, country_coords = data.country.factorize(sort=True)
_, country_group_coords = data.country_group.factorize(sort=True)
_, country_group_producer_coords = data.country_group_producer.factorize(sort=True)


with pm.Model() as hierarchical_model3:
    hierarchical_model3.add_coord("country", country_coords, mutable=False)
    hierarchical_model3.add_coord("country_group", country_group_coords, mutable=False)
    hierarchical_model3.add_coord("country_group_producer", country_group_producer_coords, mutable=False)
    hierarchical_model3.add_coord("obs_id", data.obs_id, mutable=False)
    
    X1 = pm.MutableData("X1", data["X"], dims="obs_id")
    
    country_to_country_group_idx = map_from_levels("country", "country_group", data)
    country_group_to_country_group_producer_idx = map_from_levels("country_group", "country_group_producer", data)
    country_group_producer_to_obs_idx = data.country_group_producer_index.values
    
    base_mu = pm.Uniform("base_mu", lower=1, upper=10)
    country_mu = pm.Normal("country_mu", base_mu, sigma=1, dims="country")
    country_group_mu = pm.Normal("country_group_mu", mu=country_mu[country_to_country_group_idx], sigma=1, dims="country_group")
    country_group_producer_mu = pm.Normal("country_group_producer_mu", mu=country_group_mu[country_group_to_country_group_producer_idx], sigma=1, dims="country_group_producer")
    
    sigma = pm.Uniform("sigma", lower=1, upper=10)
    pdata = pm.Normal("product_data", mu=country_group_producer_mu[country_group_producer_to_obs_idx] * X1, sigma=sigma, dims=["obs_id"], observed=data["value"])
    
with hierarchical_model3:
    prior3 = pm.sample_prior_predictive()
    
with hierarchical_model3:
    posterior3 = pm.sample()

pm.model_to_graphviz(hierarchical_model3)
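
Regarding Q2: one thing I have been considering for faster iteration is fitting the same model with ADVI while prototyping. A minimal sketch is below; mean-field ADVI is only an approximation (it will understate posterior correlations and often the variances), but it may be enough to catch shape and indexing mistakes before paying for full NUTS runs. Depending on the PyMC version, approx.sample returns either an InferenceData or a MultiTrace.

with hierarchical_model3:
    # mean-field ADVI; n is the number of optimization iterations
    approx = pm.fit(n=20_000, method="advi")
    advi_draws = approx.sample(1_000)  # draws from the fitted approximation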

As highlighted in Q3a, looking at the priors as you traverse the levels, you see the prior mean shift to the right. I understand why this is happening: as you nest and truncate distributions, the mean naturally shifts to the right. But I’m wondering what alternative modelling approaches exist for dealing with this. Ultimately, in my problem I have a prior view on producers’ elasticities to some external covariate X, and I want to pool estimates across various geographies and producer types.

az.summary(prior3.prior).filter(like="Canada", axis=0)
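
On Q3a and Q3b, one alternative I have been looking at is a non-centered parameterization, where each level is written as the previous level plus an explicitly scaled, zero-mean offset. The prior mean then does not drift as levels are added, the marginal prior variance at the bottom is (for fixed sigmas) just the sum of the squared level sigmas, and a positivity constraint can be applied once at the bottom level instead of truncating every level. Below is a sketch reusing the coords and index helpers from above; the *_offset and *_sigma names are mine, not part of the model I actually fit.

with pm.Model() as nc_model:
    nc_model.add_coord("country", country_coords, mutable=False)
    nc_model.add_coord("country_group", country_group_coords, mutable=False)
    nc_model.add_coord("country_group_producer", country_group_producer_coords, mutable=False)
    nc_model.add_coord("obs_id", data.obs_id, mutable=False)

    X1 = pm.MutableData("X1", data["X"], dims="obs_id")
    country_to_country_group_idx = map_from_levels("country", "country_group", data)
    country_group_to_country_group_producer_idx = map_from_levels("country_group", "country_group_producer", data)
    obs_idx = data.country_group_producer_index.values

    base_mu = pm.Normal("base_mu", mu=5, sigma=2)

    country_sigma = pm.HalfNormal("country_sigma", 1)
    country_offset = pm.Normal("country_offset", 0, 1, dims="country")
    country_mu = pm.Deterministic("country_mu", base_mu + country_sigma * country_offset, dims="country")

    group_sigma = pm.HalfNormal("country_group_sigma", 1)
    group_offset = pm.Normal("country_group_offset", 0, 1, dims="country_group")
    country_group_mu = pm.Deterministic(
        "country_group_mu",
        country_mu[country_to_country_group_idx] + group_sigma * group_offset,
        dims="country_group",
    )

    producer_sigma = pm.HalfNormal("country_group_producer_sigma", 1)
    producer_offset = pm.Normal("country_group_producer_offset", 0, 1, dims="country_group_producer")
    country_group_producer_mu = pm.Deterministic(
        "country_group_producer_mu",
        country_group_mu[country_group_to_country_group_producer_idx] + producer_sigma * producer_offset,
        dims="country_group_producer",
    )

    # if the elasticity must be positive, one option is to treat the levels above as
    # living on the log scale and constrain once at the bottom, e.g.
    # beta = pm.Deterministic("beta", pm.math.exp(country_group_producer_mu), dims="country_group_producer")

    sigma = pm.HalfNormal("sigma", 1)
    pm.Normal("product_data", mu=country_group_producer_mu[obs_idx] * X1, sigma=sigma, dims="obs_id", observed=data["value"])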

I then generate new covariate data and use this data to forecast

fcst_data = generate(n=10, seed=41)


_, country_coords = fcst_data.country.factorize(sort=True)
_, country_group_coords = fcst_data.country_group.factorize(sort=True)
_, country_group_producer_coords = fcst_data.country_group_producer.factorize(sort=True)


with pm.Model() as hierarchical_model3_fcst:
    hierarchical_model3_fcst.add_coord("country", country_coords, mutable=False)
    hierarchical_model3_fcst.add_coord("country_group", country_group_coords, mutable=False)
    hierarchical_model3_fcst.add_coord("country_group_producer", country_group_producer_coords, mutable=False)
    hierarchical_model3_fcst.add_coord("obs_id", fcst_data.obs_id, mutable=False)
    
    X1 = pm.MutableData("X1", fcst_data["X"], dims="obs_id")
    
    country_group_producer_to_obs_idx = fcst_data.country_group_producer_index.values
    country_group_producer_mu = pm.Flat("country_group_producer_mu", dims="country_group_producer")
    
    mu = pm.Deterministic("mu", country_group_producer_mu[country_group_producer_to_obs_idx] * X1, dims=["obs_id"])
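
One thing to be careful about with this pm.Flat trick (if I understand sample_posterior_predictive correctly): the draws are matched to the new model by variable name and position, not by coordinate labels, so the forecast model’s country_group_producer coord has to contain the same labels in the same order as the fitted model, or the Flat values will be silently mis-aligned. A cheap runtime check along these lines seems worthwhile (an alternative design would be to register obs_id as a mutable coord and wrap the covariate and index in pm.MutableData in the fitted model, then swap in the forecast data with pm.set_data instead of building a second model):

# guard against coord mis-alignment between the fitted posterior and the forecast model
assert list(country_group_producer_coords) == list(
    posterior3.posterior["country_group_producer_mu"].coords["country_group_producer"].values
)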

As mentioned in Q1 above, sampling the posterior predictive is quite slow: it takes about 1.7 s here (on my real example it takes around 30 minutes), versus calculating from the posterior directly:

# slow sampling
posterior_pred3 = pm.sample_posterior_predictive(posterior3.posterior.drop_vars(["base_mu", "country_mu", "country_group_mu", "sigma"]), model=hierarchical_model3_fcst, var_names=["mu"])
pred3 = posterior_pred3.posterior_predictive.mu.mean(dim=["chain", "draw"]).to_series()
pred3.index = pd.MultiIndex.from_frame(fcst_data.set_index("obs_id").loc[pred3.index, ["country_group_producer", "date"]])
pred3 = pred3.unstack(level="country_group_producer")


def forecast_hierarchical(posterior, model_data):
    country_group_producer_mu = posterior.country_group_producer_mu

    x_data = model_data.set_index(["country_group_producer", "date"])["X"]
    # how to properly broadcast across xarray
    res = country_group_producer_mu.mean(["chain", "draw"]).to_series() * x_data
    res = res.unstack(level="country_group_producer")
    return res

# fast using posterior directly
alt_pred3 = forecast_hierarchical(posterior3.posterior, fcst_data)
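
On Q1 and the "how to properly broadcast across xarray" comment inside forecast_hierarchical: I believe most of the sample_posterior_predictive time goes into compiling a function and looping over the draws, whereas the xarray path is a single vectorized multiplication, so a large gap is probably expected for a pure Deterministic like mu. The broadcasting can also be done without collapsing to the posterior mean first, by selecting the producer-level coefficient for each observation with a vectorized .sel and letting xarray broadcast over chain and draw. A rough sketch (xr is xarray; the variable names are mine):

import xarray as xr

# full posterior of the producer-level coefficient: dims (chain, draw, country_group_producer)
mu_post = posterior3.posterior["country_group_producer_mu"]

# per-observation labels and covariates as DataArrays over obs_id
obs_levels = xr.DataArray(fcst_data["country_group_producer"].values, dims="obs_id")
X_new = xr.DataArray(fcst_data["X"].values, dims="obs_id")

# vectorized selection picks the coefficient for each observation, then the
# multiplication broadcasts over chain and draw: result dims (chain, draw, obs_id)
mu_pred = mu_post.sel(country_group_producer=obs_levels) * X_new

# point forecast comparable to pred3 / alt_pred3
point_pred = mu_pred.mean(dim=["chain", "draw"])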

display(pred3)
