Prediction on Hierarchical model

Hi,

I am trying to predict on my test set from my hierarchical model, the train dataset has 104 data points per market segment and the test dataset has 13 data points per market segment. When I am trying to predict I en up just doing the posterior on the train data. I am guessing the pm.data_set is not set up properly, I am working with a data frame, do I need to change the format or anything else?

Thanks for your help!

with pm.Model() as bhmmm:
    
    coef_media  = pm.Gamma("coef_media", alpha=2.0, beta=0.5)
    alpha_shape_hyper = pm.TruncatedNormal('alpha_shape_hyper', mu=0, sigma=1, lower=0.5, upper=3)
    lambda_shape_hyper = pm.Beta('lambda_shape_hyper', alpha=2, beta=5)   
    
    coef_base = pm.Exponential("coef_lam2", lam=1)
    sigma_slope = pm.HalfNormal("sigma_slope", sigma=1.0)

     
    for geo_idx, market_segment in enumerate(final_data_scaled_train["market_segment"].unique()):
        X_ = final_data_scaled_train[final_data_scaled_train["market_segment"] == market_segment]
        channel_contributions = []
        
        intercept_sigma = pm.HalfNormal(f'intercept_sigma_{market_segment}', sigma=1.0)
        #intercept_rw = pm.GaussianRandomWalk(f'intercept_{market_segment}', sigma=intercept_sigma, shape=len(X_))


        for channel_idx, channel in enumerate(delay_and_sat_channels):
            channel_data = X_[channel].to_numpy()
            
            coef = pm.Deterministic(f"coef{channel}_{market_segment}", coef_media)
            adstock = pm.Deterministic(f'adstock_{channel}_{market_segment}', lambda_shape_hyper)
            alpha = pm.Deterministic(f'alpha_{channel}_{market_segment}',  alpha_shape_hyper)

            channel_contribution = pm.Deterministic(
                f"contribution_{channel}_{market_segment}",
                coef * logistic_saturation(geometric_adstock(channel_data,
                                                             alpha= adstock,
                                                             l_max=8, normalize=False).eval().flatten(),
                                           lam=alpha).eval(),
            )
            channel_contributions.append(channel_contribution)

        for control in control_variables:
            if control == 'Image_2022_09_03_Dummy' and market_segment == 'Image':
                #coef_trend = pm.Exponential(f"coef_{control}_{market_segment}", lam=0.05)
                coef_trend = pm.Gamma(f"coef_{control}_{market_segment}", alpha=2.0, beta=0.5)
                trend_contribution = pm.Deterministic(f"contribution_{control}_{market_segment}", coef_trend * X_[control].values)
                channel_contributions.append(trend_contribution)
            elif control != 'Image_2022_09_03_Dummy':
                coef2 = pm.Exponential(f"coef_{control}_{market_segment}", lam=coef_base)
                control_contribution = pm.Deterministic(f"contribution_{control}_{market_segment}", coef2 * X_[control].values)
                channel_contributions.append(control_contribution)
                
        noise = pm.HalfCauchy(f"noise_{market_segment}", beta=sigma_slope)

        sales = pm.Normal(
            f"sales_{market_segment}",
            mu=intercept_sigma + sum(channel_contributions),
            sigma=noise,
            observed=X_['y'].to_numpy(),
        )

with  bhmmm:
    trace = pm.sample(draws=10000, chains=4, tune=5000)
    
    for market_segment in final_data_scaled_test["market_segment"].unique():
        X_test = final_data_scaled_test[final_data_scaled_test["market_segment"] == market_segment]
        
        for channel in delay_and_sat_channels:
            channel_key = f'{channel}_{market_segment}'  
            if channel_key in bhmmm.named_vars: 
                pm.set_data({channel_key: X_test[channel].to_numpy()})
                
        for control in control_variables:
            control_key = f'{control}_{market_segment}'  
            if control_key in bhmmm.named_vars:  
                pm.set_data({control_key: X_test[control].to_numpy()})
                
        pp = pm.sample_posterior_predictive(trace, predictions=True)