Fail to predict on new/hold-out data with nested multilevel/hierarchical model

Hi!

I appreciate any help in solving this.

UPDATE: I found a workaround, but it feels like a bug, so the problem is not really solved. The hold-out data must have exactly the same length as the training data (1000 samples), even if my test set has only 3 or 100 samples: no fewer, no more. Why?

TL;DR: Input dimension mismatch for new data

I’ve come a long way with PyMC by gradually increasing the complexity of the model with each iteration, up to this point.

To clarify:

  • pm.sample_numpyro_nuts() works
  • pm.sample_posterior_predictive() works in isolation
  • pm.sample_posterior_predictive() doesn’t work with new/unseen/hold-out data

The structure of my model is Global → Second level → Third level → Output (see image).

Here is the code. Note that the model is trained on 1000 samples in total. Please excuse the inconsistent variable names; they are an artifact of the iterative model building and have no effect on the sampling.

coords dictionary

coords = {
    'loc': df.index.values,
    'branch': np.unique(branch_code),
    'city': np.unique(city_code),
    'productline': np.unique(productline_code),
    'branch_city': np.arange(len(unique_combo_branch_city)),
    'city_productline': np.arange(len(unique_combo_city_productline))
}
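The model relies on index arrays whose construction the post does not show (branch_city_idx, city_product_idx, unique_combo_branch_city, city_product_to_branch). One plausible way to build them is sketched below with np.unique(..., axis=0, return_inverse=True); the toy codes and this construction are assumptions, not the author's actual preprocessing. The point is that every row carries an integer index into a group-level parameter vector, so hold-out rows need index arrays of their own.

```python
import numpy as np

# Hypothetical integer codes, one entry per observation (the real ones come from df)
branch_code = np.array([0, 0, 1, 1, 2, 2, 0])
city_code = np.array([0, 1, 2, 2, 3, 3, 1])         # each city sits inside one branch
productline_code = np.array([3, 5, 3, 3, 5, 1, 3])

# Second level: unique (branch, city) pairs; the inverse is each row's branch_city index
unique_combo_branch_city, branch_city_idx = np.unique(
    np.column_stack([branch_code, city_code]), axis=0, return_inverse=True
)
branch_city_idx = branch_city_idx.ravel()  # keep it 1-D across NumPy versions

# Third level: unique (branch_city, productline) pairs nested in the second level
unique_combo_city_productline, city_product_idx = np.unique(
    np.column_stack([branch_city_idx, productline_code]), axis=0, return_inverse=True
)
city_product_idx = city_product_idx.ravel()

# Parent lookup: the branch_city group that each city_productline group belongs to
city_product_to_branch = unique_combo_city_productline[:, 0]
```

Indexing a parent-level parameter with city_product_to_branch (as mu_city[city_product_to_branch] does in the model) gives each third-level group the values of its own second-level parent.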

Model below

with pm.Model(coords=coords) as multilevel_model_ts3:
    # Indices
    branch_cty_idx = pm.MutableData("branch_idx", branch_city_idx, dims="loc")
    city_productline_idx = pm.MutableData("city_idx", city_product_idx, dims="loc")
    y_mut = pm.MutableData("y_mut", y, dims="loc")
    t_mut = pm.MutableData("t_mut", t, dims="loc")
    tax_mut = pm.MutableData("tax_mut",  tax, dims="loc")
    ewallet_mut = pm.MutableData("ewallet_mut", ewallet_code, dims="loc")
    cash_mut = pm.MutableData("cash_mut", cash_code, dims="loc")
    
    # Global hyperparameters
    g_mu = pm.Normal("g_mu", mu=0, sigma=0.01)
    g_sd = pm.HalfNormal("g_sd", sigma=0.01)
    
    # Global intercept
    a = pm.Normal("alpha", mu=15, sigma=.001)

    # Branch level intercept
    branch_intercept = pm.Normal("branch_intercept", mu=15, sigma=0.1, dims="branch")

    # Country-level parameters
    mu_country = pm.Normal("mu_branch", mu=g_mu, sigma=g_sd, dims="branch")
    sigma_country = pm.HalfNormal("sigma_branch", sigma=0.1, dims="branch")
    
    # City level intercept
    city_intercept = pm.Normal("city_intercept", mu=15, sigma=.1, dims="branch_city")

    # City-level parameters
    mu_city = pm.Normal("mu_city", 
                        mu=mu_country[unique_combo_branch_city[:, 0]], 
                        sigma=sigma_country[unique_combo_branch_city[:, 0]], 
                        dims="branch_city")
    sigma_city = pm.HalfNormal("sigma_city", sigma=1, dims="branch_city")
    
    # City-Article level parameters
    mu_city_article = pm.Normal("mu_city_article", 
                                mu=mu_city[city_product_to_branch], 
                                sigma=sigma_city[city_product_to_branch], 
                                dims="city_productline")
    

    # Time effect at branch-city level
    beta_time_branch_city = pm.Normal("beta_time_branch_city", mu=0, sigma=.01, dims="branch_city")

    # Time effect
    beta_time_city_productline = pm.Normal("beta_time_city_productline", mu=0, sigma=5, dims="city_productline")

    # Tax 5% predictor
    kappa_tax_branch_city = pm.Normal("kappa_tax_branch_city", mu=0, sigma=2, dims="branch_city")

    kappa_tax_city_productline = pm.Normal("kappa_tax_city_productline", mu=kappa_tax_branch_city[city_product_to_branch], sigma=2, dims="city_productline")

    # Payment predictor
    ewallet_payment_branch_city = pm.Normal("ewallet_payment_branch_city", mu=0, sigma=5, dims="branch_city")
    cash_payment_branch_city = pm.Normal("cash_payment_branch_city", mu=0, sigma=5, dims="branch_city")

    intercepts = a + branch_intercept[branch_cty_idx] + city_intercept[branch_cty_idx] + mu_city_article[city_productline_idx]
    
    # Expected value
    mu = pm.Deterministic("mu", 
                          intercepts + 
                          beta_time_branch_city[branch_cty_idx] * t_mut +
                          beta_time_city_productline[city_productline_idx] * t_mut +
                          kappa_tax_branch_city[branch_cty_idx] * tax_mut +
                          kappa_tax_city_productline[city_productline_idx] * tax_mut +
                          ewallet_payment_branch_city[branch_cty_idx] * ewallet_mut +
                          cash_payment_branch_city[branch_cty_idx] * cash_mut, 
                          dims="loc")
    
    # Observation noise
    sigma = pm.HalfNormal("sigma", sigma=1)
    

    # Likelihood
    obs = pm.LogNormal("obs", mu=mu, sigma=sigma, observed=y_mut, dims="loc")

Sampling below

with multilevel_model_ts3:
    trace_ts2 = sample_numpyro_nuts(draws=1000, tune=1000, target_accept=0.95)

Posterior predictive below

with multilevel_model_ts3:
    posterior_predictive_ts2 = pm.sample_posterior_predictive(trace_ts2, var_names=["obs"], random_seed=1990)

Up to this point everything works. Running the code below, however, raises an input dimension mismatch error.

with multilevel_model_ts3:
    pm.set_data({"y_mut": [0,0,0],
                 "t_mut": [0, .3, .5],
                 "tax_mut": [0, .3, .5],
                 "ewallet_mut": [0, 1, 0],
                 "cash_mut": [0, 0, 1]},
                 
                 coords={"loc": [1000, 1001, 1002],
                         "branch": [0, 0, 1],
                         "branch_city": [0, 0, 1],
                         "city": [0, 1, 0],
                         "productline": [3, 3, 5],
                         "city_productline": [10, 11, 15]})

    pp = pm.sample_posterior_predictive(trace_ts2, var_names=["obs"], random_seed=1990 )

The error

ValueError: Input dimension mismatch: (input[%i].shape[%i] = %lld, input[%i].shape[%i] = %lld)
Apply node that caused the error: Composite{(i2 + i3 + i4 + i5 + (i13 * i12) + (i11 * i12) + (i10 * i9) + (i8 * i9) + (i6 * i7) + (i0 * i1))}(AdvancedSubtensor1.0, cash_mut, ExpandDims{axis=0}.0, AdvancedSubtensor1.0, AdvancedSubtensor1.0, AdvancedSubtensor1.0, AdvancedSubtensor1.0, ewallet_mut, AdvancedSubtensor1.0, tax_mut, AdvancedSubtensor1.0, AdvancedSubtensor1.0, t_mut, AdvancedSubtensor1.0)
Toposort index: 11
Inputs types: [TensorType(float64, shape=(None,)), TensorType(int32, shape=(None,)), TensorType(float64, shape=(1,)), TensorType(float64, shape=(None,)), TensorType(float64, shape=(None,)), TensorType(float64, shape=(None,)), TensorType(float64, shape=(None,)), TensorType(int32, shape=(None,)), TensorType(float64, shape=(None,)), TensorType(float64, shape=(None,)), TensorType(float64, shape=(None,)), TensorType(float64, shape=(None,)), TensorType(float64, shape=(None,)), TensorType(float64, shape=(None,))]
Inputs shapes: [(1000,), (3,), (1,), (1000,), (1000,), (1000,), (1000,), (3,), (1000,), (3,), (1000,), (1000,), (3,), (1000,)]
Inputs strides: [(8,), (4,), (8,), (8,), (8,), (8,), (8,), (4,), (8,), (8,), (8,), (8,), (8,), (8,)]
Inputs values: ['not shown', array([0, 0, 1]), array([14.99939276]), 'not shown', 'not shown', 'not shown', 'not shown', array([0, 1, 0]), 'not shown', array([0. , 0.3, 0.5]), 'not shown', 'not shown', array([0. , 0.3, 0.5]), 'not shown']
Outputs clients: [[lognormal_rv{"(),()->()"}(RNG(<Generator(PCG64) at 0x2AB45632C00>), MakeVector{dtype='int64'}.0, mu, ExpandDims{axis=0}.0)]]

I have followed the pymc.model.core.set_data page in the PyMC dev documentation and the General API quickstart in the PyMC example gallery, but I still get the mismatch error.
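Reading the Inputs shapes line of the error, every (3,) input is a container that was passed to pm.set_data (cash_mut, ewallet_mut, tax_mut, t_mut), while every (1000,) input is an AdvancedSubtensor1 output, i.e. a group-level parameter indexed by branch_idx or city_idx. Those two index containers are not in the pm.set_data call, so they still hold the 1000 training rows. The same failure can be reproduced in plain NumPy; the arrays below are hypothetical stand-ins for the model's tensors.

```python
import numpy as np

rng = np.random.default_rng(1990)
n_branch_city = 4
beta_time = rng.normal(size=n_branch_city)  # stand-in for a parameter with dims="branch_city"

stale_idx = rng.integers(0, n_branch_city, size=1000)  # index data left at training length
t_new = np.array([0.0, 0.3, 0.5])                      # updated predictor, hold-out length

try:
    beta_time[stale_idx] * t_new  # (1000,) * (3,) cannot broadcast
    mismatch_raised = False
except ValueError:
    mismatch_raised = True

new_idx = np.array([0, 0, 1])        # index data updated together with the predictors
mu_new = beta_time[new_idx] * t_new  # (3,) * (3,) -> (3,)
```

This also fits the UPDATE at the top: padding the hold-out set to 1000 rows makes the lengths agree again, but the padded rows are then combined with the stale training indices.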

This is a documented “gotcha” when working with out-of-sample data. PyMC builds a static computational graph representing your model, so shapes aren’t allowed to change arbitrarily. They can be dynamic, but they have to be explicitly declared as such. For example, you can connect the shape of the observed data to the shape of the input data that you plan to change during post-estimation tasks. See here for an example and discussion.

Alternatively, if you use dims and update them when you set the new data, it will also work. Dims are another way of specifying “dynamic shapes”.
