PyMC5 out of sample

When debugging models it’s useful to use pm.model_to_graphviz to check for any obvious problems. In this case, the second model shows this graph:

As you can see, control_data and control_betas are not connected to anything. This is because the result of the tt.set_subtensor line is never saved: set_subtensor returns a new tensor rather than modifying contributions in place. The line should read contributions = tt.set_subtensor(contributions[:, n_features:], Z @ control_betas)

Looking at this graph and playing with it led me to rewrite the model slightly by 1) assigning the deterministic to a variable, and 2) using a list for contributions and appending the elements to it, similar to your original model. I think this is a bit easier to follow:

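# Assumed imports; the `tt` alias follows the naming in your post
import pymc as pm
import pytensor.tensor as tt

df_features  = df[features]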
n_obs, n_features = df_features.shape
coords = {'time':df_features.index,
          'features':features, 
          'all_vars':features}

n_controls = 0  # default, so the check inside the model also works without control variables
if control_vars:
    df_controls  = df[control_vars]
    _, n_controls = df_controls.shape
    coords.update({'controls':control_vars,
                   'all_vars':features + control_vars})

with pm.Model(coords=coords) as basic_model_2:
    X = pm.MutableData('feature_data', df_features.values, dims=['time', 'features'])
    y = pm.MutableData('targets', df['y'].values.squeeze(), dims=['time'])

    n_obs = X.shape[0]
    
    betas = pm.HalfNormal('beta', sigma = 2, dims=['features'])
    decays = pm.Beta('decay', alpha=3, beta=3, dims=['features'])
    sat = pm.Gamma('sat', alpha=3, beta=1, dims=['features'])
    contributions = []
    
    for i in range(n_features):
        x = logistic_function(geometric_adstock_tt(X[:, i], decays[i]),sat[i])*betas[i]
        contributions.append(x)
        
    if n_controls > 0:
        Z = pm.MutableData('control_data', df_controls.values, dims=['time', 'controls'])
        control_betas = pm.Normal('control_beta', sigma = 2, dims=['controls'])
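        # One column per control, so the stacked array matches the 'all_vars' dim
        for j in range(n_controls):
            contributions.append(Z[:, j] * control_betas[j])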
    
    mu = pm.Deterministic("contributions", tt.stack(contributions).T, dims=['time', 'all_vars'])
    sigma = pm.HalfNormal('sigma', sigma=1)
    
    y_hat = pm.Normal("y_hat", mu=mu.sum(axis=-1), sigma=sigma, observed=y, shape=X.shape[0], dims=['time'])

This generates the following graph:

It should be equivalent to yours. Note that I set a coord on the time index, so you’ll have to pass updated coords for it when you call pm.set_data for the out-of-sample data.