Help with recreating the PyMC marketing sample_posterior_predictive function

I have a PyMC MMM model defined on a training set by

model_config={
    "SEARCH": Prior("Normal", mu=0, sigma=0.2),
    "PERFORMANCE_MAX": Prior("Normal", mu=0, sigma=0.2),
    "SHOPPING": Prior("Normal", mu=0, sigma=0.2),
    "acquisition": Prior("Normal", mu=0, sigma=0.2),
    "remarketing": Prior("Normal", mu=0, sigma=0.2),
    "kl_purchases": Prior("Normal", mu=0, sigma=0.2),
}
# Initialize the MMM model  
mmm = MMM(
    adstock=GeometricAdstock(l_max=12),
    saturation=HillSaturation(),
    date_column="date",
    channel_columns=channel_columns,
    yearly_seasonality=12,
    sampler_config = {"target_accept": 0.95},
    model_config= model_config,
)
# Fit and sample
mmm.fit(X_train, y_train, decompose_contributions=True)

In order to fully understand the model being used, I would like to recreate the function

idata_counter = mmm.sample_posterior_predictive(X_counter, extend_idata=False, include_last_observations=True)

for some counterfactual data. Here is the code of the function with which I am trying to do that.

from sklearn.preprocessing import MaxAbsScaler

# Fit scaler on training channel data
channel_scaler = MaxAbsScaler()
channel_scaler.fit(X_train[mmm.channel_columns])  # Use training X used for model fit

def predict_mmm_stochastic_multiple_days_lmax(
    X: pd.DataFrame,
    mmm,
    X_train: pd.DataFrame,
    n_samples: int = 4000,
    ci: float = 0.9,
    l_max: int = 12,
    add_noise: bool = True
) -> dict:
    """
    Predict revenue stochastically for multiple days using posterior samples from the MMM model,
    taking into account geometric adstock with window l_max.

    Returns:
        dict: {
            "mean": np.ndarray,
            "lower": np.ndarray,
            "upper": np.ndarray,
            "samples": np.ndarray (shape: n_samples x n_days)
        }
    """
    idata = mmm.idata
    posterior = idata.posterior

    # Flatten
    flat = posterior.stack(sample=("chain", "draw"))

    flat = flat.isel(sample=slice(0, n_samples)) 

    # Prepare data: prepend l_max days from training to get adstock right
    X_hist = X_train.tail(l_max)
    X_full = pd.concat([X_hist, X], ignore_index=True)
    X_full = X_full.sort_values(by="date").reset_index(drop=True)
    dates = pd.to_datetime(X_full["date"]).reset_index(drop=True)

    # Scale channels
    X_channels_scaled = pd.DataFrame(
        channel_scaler.transform(X_full[mmm.channel_columns]),
        columns=mmm.channel_columns
    ).reset_index(drop=True)

    n_days = len(X)
    all_samples = np.zeros((n_samples, n_days))

    for i in range(n_samples):
        s = flat.isel(sample=i)
        intercept = s["intercept"].item()

        for d in range(l_max, len(X_full)):  # start after adstock warmup

            total = intercept
            d_out = d - (l_max)  # index into output prediction array

            for channel in mmm.channel_columns:
                alpha = s["adstock_alpha"].sel(channel=channel).item()
                beta = s["saturation_beta"].sel(channel=channel).item()
                kappa = s["saturation_kappa"].sel(channel=channel).item()

                # Apply geometric adstock over window [d-l_max+1, d]
                adstock_sum = 0.0
                for l in range(l_max):
                    idx = d - l
                    if idx >= 0:
                        spend = X_channels_scaled.loc[idx, channel]
                        adstock_sum += (alpha ** l) * spend

                x_adstock = adstock_sum
                x_adstock = max(x_adstock, 1e-8)
                response = beta * (x_adstock ** kappa) / (1 + x_adstock ** kappa)
                total += response

            # Add control variables if present
            if hasattr(mmm, "control_columns") and mmm.control_columns:
                for col in mmm.control_columns:
                    mu = s["control_mu"].sel(control=col).item()
                    control_val = X_full.loc[d, col]
                    total += mu * control_val

            seasonality = 0.0
            # Add seasonality if enabled
            if mmm.yearly_seasonality > 0:
                t = 2 * np.pi * dates[d].dayofyear / 365.25
                for j in range(1, mmm.yearly_seasonality + 1):
                    sin_term = s["gamma_fourier"].sel(fourier_mode=f"sin_{j}").item()
                    cos_term = s["gamma_fourier"].sel(fourier_mode=f"cos_{j}").item()
                    total += sin_term * np.sin(j * t) + cos_term * np.cos(j * t)
                    seasonality += sin_term * np.sin(j * t) + cos_term * np.cos(j * t)

            if add_noise:
                y_sigma = s["y_sigma"].item()
                noise = np.random.normal(0, y_sigma)
                total += noise

            # Store result for this sample/day
            all_samples[i, d_out] = total


    samples_original = mmm.get_target_transformer().inverse_transform(all_samples)

    
    # Then compute summary statistics on original scale
    lower = np.percentile(samples_original, (1 - ci) / 2 * 100, axis=0)
    upper = np.percentile(samples_original, (1 + ci) / 2 * 100, axis=0)
    mean = samples_original.mean(axis=0)


    return {
    "mean": mean,
    "lower": lower,
    "upper": upper,
    "samples": samples_original,  # <-- instead of all_samples
    }

I tried reading through the library to understand what I could be missing but with not much succes. I’m getting a noticable discrepancy between the two approaches. The channel_scaler is defined manually and target_scaler is given by Pipeline(steps=[(‘scaler’, MaxAbsScaler())]).

Any help and ideas would be greatly appreciated!