Predictions for out-of-sample indexed categorical variables

Hey folks, I’ve found that making predictions for out-of-sample levels of indexed categorical variables is a huge pain point in my workflow, and it usually requires me to write a custom predict function for each model. Those functions get complicated quickly and are a prime spot for bugs to creep in (not to mention my coworkers hate reviewing them).

The most common cause is nested hierarchies. Here’s an example of where it pops up, which I’ve tried to make as simple as possible.

$$
\begin{aligned}
\alpha_{\text{vendor}} &\sim \text{Normal}(\mu_{\text{vendor}}, \sigma_{\text{vendor}}) \\
\mu_{\text{product}} &= \alpha_{\text{vendor}} + \beta X \\
y &\sim \text{Normal}(\mu_{\text{product}}, \sigma)
\end{aligned}
$$

There are vendors, products specific to each vendor, and product-level features to map to. Each unit of a product has some outcome y, which is a draw from a normal distribution whose mean is the product mean, mu_product.
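(The real generating code is in the attached notebook, but roughly the data layout I’m assuming looks like this: prod_df has one row per product with its vendor and feature X, and df has one row per observed unit with its product and outcome y. The numbers below are purely illustrative.)

import numpy as np
import pandas as pd

rng = np.random.default_rng(0)

# product-level table: each product belongs to exactly one vendor
prod_df = pd.DataFrame({
    "vendor": pd.Categorical(["v0", "v0", "v1", "v1", "v2"]),
    "product": pd.Categorical(["p0", "p1", "p2", "p3", "p4"]),
    "X": rng.normal(size=5),
})

# made-up "true" parameters, just to get plausible y values
a_vendor_true = rng.normal(0, 1, size=3)
mu_product_true = a_vendor_true[prod_df["vendor"].cat.codes.to_numpy()] + 0.5 * prod_df["X"].to_numpy()

# unit-level table: several observed units per product
p_idx = rng.integers(0, len(prod_df), size=200)
df = pd.DataFrame({
    "product": prod_df["product"].iloc[p_idx].reset_index(drop=True),
    "y": rng.normal(mu_product_true[p_idx], 0.3),
})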

import pymc as pm

coords = {
    "vendor": prod_df.vendor.cat.categories,
    "product": df["product"].cat.categories,
}

with pm.Model(coords=coords) as model:

    # vendor of each product, and product of each observed unit
    v_ = pm.MutableData("vendor", prod_df.vendor.cat.codes.values)
    p_ = pm.MutableData("product", df["product"].cat.codes.values)
    # product level features
    X_ = pm.MutableData("X", prod_df.X.values)
    
    mu_vendor = pm.Normal("mu_vendor", 0, 3)
    sig_vendor = pm.Exponential("sig_vendor", 1)
    a_vendor = pm.Normal("a_vendor", mu_vendor, sig_vendor, dims="vendor")
    

    beta = pm.Normal("beta", 0, 1)
    mu_product = pm.Deterministic("mu_product", a_vendor[v_] + beta*X_, dims="product")
    
    sig = pm.Exponential("sig", 1)
    y = pm.Normal("obs", mu_product[p_], sig, observed=df.y.values)
    trace = pm.sample()

I’m attaching a notebook with all of the code to generate this example and test out predictions.

Basically, with that model formulation there’s no way to make predictions on out-of-sample products or vendors without writing it out by hand. Even having an “other” category wouldn’t work for vendors because of this line:

mu_product = pm.Deterministic("mu_product", a_vendor[v_] + beta*X_, dims="product")
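For example, a truly new vendor doesn’t have a row in a_vendor at all, so naively swapping in its code blows up (a sketch of the failure mode, the exact error depends on the PyMC version):

with model:
    # give the first product a vendor code that wasn't in the training data
    new_vendor_codes = prod_df.vendor.cat.codes.values.copy()
    new_vendor_codes[0] = len(prod_df.vendor.cat.categories)  # one past the last known vendor
    pm.set_data({"vendor": new_vendor_codes})
    # a_vendor only has one row per training vendor, so a_vendor[v_] indexes
    # out of bounds and posterior predictive sampling fails here
    ppc = pm.sample_posterior_predictive(trace)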

Any ideas here on how to make predictions without writing a custom predict function for each model?
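For reference, the hand-written version I end up with looks roughly like this (a sketch for a single new product from a brand-new vendor, with an assumed feature value x_new):

import numpy as np

post = trace.posterior.stack(sample=("chain", "draw"))
rng = np.random.default_rng(1)

x_new = 0.2  # assumed feature value for the new product

# a brand-new vendor needs a fresh intercept, drawn from the vendor-level
# distribution once per posterior sample
a_new = rng.normal(post["mu_vendor"].values, post["sig_vendor"].values)

# product mean and predictive draws for the new product
mu_new = a_new + post["beta"].values * x_new
y_new = rng.normal(mu_new, post["sig"].values)

It works, but every model needs its own version and it has to be kept in sync with the model definition by hand.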

I’m curious whether it’s possible to recursively walk backwards through the aesara graph any time an indexed unobserved RV is out of sample, until you reach a node that is in sample, but I haven’t gotten that far yet. Or who knows, maybe there’s a more obvious fix here I’m missing.
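To make that concrete, the first step I had in mind was something like this (untested sketch using aesara’s graph utilities): find which free RVs a given variable depends on, so you know what would have to be re-drawn for an out-of-sample index.

from aesara.graph.basic import ancestors

def upstream_free_rvs(var, model):
    # walk backwards through the aesara graph from `var` and keep the
    # model's free RVs that appear among its ancestors
    anc = set(ancestors([var]))
    return [rv for rv in model.free_RVs if rv in anc]

# e.g. upstream_free_rvs(model["mu_product"], model) should surface a_vendor
# and beta, which is where new draws for an unseen vendor would have to happen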

(to use the notebook, change the extension to .ipynb)
PyMC out of sample predictions-revisited.py (88.9 KB)