Finding best lag via correlation between PyTensor Variable and a fixed vector while defining a model

I am implementing a Bayesian Regression Model below. For each independent variable in the model, a lag is applied such that the correlation between the lagged independent variable and the dependent variable is the highest. The reproducible code is as follows. I am getting an index out of bounds error. My reproducible code is as follows:

## Create a simple MMM data 
import pandas as pd
from random import randint
import numpy as np
import pytensor.tensor as tt
import pytensor as pt
import pymc as pm
import pymc.sampling.jax as pmjax
import arviz as az

# # Disable most optimizations
# pt.config.optimizer = 'fast_compile'
# # or completely turn off optimizations
# # pt.config.optimizer = 'None'

# # Increase exception verbosity
# pt.config.exception_verbosity = 'high'

## Functions to get the best lag

def correlation_coefficient(X, Y):
    """
    Calculate the correlation coefficient between two theano tensors.
    """
    X_mean = tt.mean(X)
    Y_mean = tt.mean(Y)

    X_std = tt.std(X)
    Y_std = tt.std(Y)

    covariance = tt.mean((X - X_mean) * (Y - Y_mean))
    return covariance / (X_std * Y_std)

def create_lagged_vector(X, lag):
    # Function to create a lagged version of a vector
    if lag == 0:
        return X
    else:
        return tt.concatenate([tt.zeros(lag), X[:-lag]])

def find_optimal_lag(X1, y, max_lag):
    best_lag = 0
    best_correlation = -np.inf

    for lag in range(max_lag + 1):
        lagged_X1 = create_lagged_vector(X1, lag)
        correlation = correlation_coefficient(lagged_X1, y)

        # Update best lag if this is the highest correlation so far
        best_correlation = pm.math.switch(tt.gt(correlation,best_correlation),correlation,best_correlation)
        best_lag = pm.math.switch(tt.gt(correlation,best_correlation),lag,best_lag)
        # if correlation > best_correlation:
        #     best_lag = lag
        #     best_correlation = correlation

    return best_lag, best_correlation





# Generate date range
dates = pd.date_range(start="2021-01-01", end="2022-01-01")

data = {
    "date": dates,
    "gcm_direct_Impressions": [randint(10000, 20000) for _ in dates],
    "display_direct_Impressions" :[randint(100000,150000) for _ in dates],
    "tv_grps": [randint(30, 50) for _ in dates],
    "tiktok_direct_Impressions": [randint(10000, 15000) for _ in dates],
    "sell_out_quantity": [randint(150, 250) for _ in dates]
}
df = pd.DataFrame(data)
m = max(df['sell_out_quantity'].values)

print(f"Max sales Volume {m}")

channel_columns = [col for col in df.columns if 'Impressions' in col or 'grps' in col]

transform_variables = channel_columns


delay_channels = channel_columns

media_channels = channel_columns

target = 'sell_out_quantity'

### Transform each channel variable

data_transformed = df.copy()

numerical_encoder_dict = {}


for feature in transform_variables:
    # Extracting the original values of the feature.
    original = df[feature].values

    # Calculating the maximum value of the feature.
    max_value = original.max()

    # Dividing each value in the feature by the maximum value.
    transformed = original / max_value

    # Storing the transformed data back into the 'data_transformed' DataFrame.
    data_transformed[feature] = transformed

    # Storing the maximum value used for scaling in the dictionary.
    # This will be used for reversing the transformation if needed.
    numerical_encoder_dict[feature] = max_value



def adstock_transform(x, rate,max_lag):
    """ Apply adstock transformation with PyTensor.
    :param x: PyTensor tensor, original data for the channel
    :param rate: PyTensor tensor, decay rate of the adstock transformation
    :param max_lag: int, maximum lag to consider for the adstock effect
    :return: PyTensor tensor, transformed data
    """
    # Creating a tensor to store transformed values
    adstocked = tt.zeros_like(x)
    
    for i in range(max_lag, x.shape[0]):
        weights = tt.power(rate, tt.arange(max_lag + 1))
        adstocked = tt.set_subtensor(adstocked[i], tt.dot(x[i-max_lag:i+1][::-1], weights))
    
    return adstocked

### Create a model
response_mean = []

with pm.Model() as model_2:
    # Looping through each channel in the list of delay channels.
    for channel_name in delay_channels:
        print(f"Delay Channels: Adding {channel_name}")

        # Extracting the transformed data for the current channel.
        x = data_transformed[channel_name].values

        # Defining Bayesian priors for the adstock, gamma, and alpha parameters for the current channel.
        adstock_param = pm.Beta(f"{channel_name}_adstock", 2, 2)
        saturation_gamma = pm.Beta(f"{channel_name}_gamma", 2, 2)
        saturation_alpha = pm.Gamma(f"{channel_name}_alpha", 3, 1)
        rate = pm.Beta(f'{channel_name}_rate', alpha=1, beta=1)
        ### Getting a adstocked transformed vector
        transformed_X1 = tt.zeros_like(x)
        for xi in range(0, len(x)):
            if xi == 0:
                transformed_X1 = tt.set_subtensor(transformed_X1[xi],x[xi])
            else:

                transformed_X1 = tt.set_subtensor(transformed_X1[xi],(transformed_X1[xi-1]*rate)+x[xi])

        ## Uncover the best lag for each channel        

        max_lag = 17

        y = tt.as_tensor(df['sell_out_quantity'].values)

        best_lag,best_correlation=find_optimal_lag(transformed_X1, y, max_lag)

        lagged_X1 = tt.concatenate([tt.zeros(best_lag),transformed_X1[:-best_lag]])


        
        ### Apply hill transform

        transformed_X2 = tt.zeros_like(x)
        for i in range(1,len(x)):
            transformed_X2 = tt.set_subtensor(transformed_X2[i],(lagged_X1[i]**saturation_alpha)/(lagged_X1[i]**saturation_alpha+saturation_gamma**saturation_alpha))
        channel_b = pm.HalfNormal(f"{channel_name}_media_coef", sigma = m)
        response_mean.append(transformed_X2 * channel_b)

    intercept = pm.Normal("intercept",mu = np.mean(data_transformed[target].values), sigma = 3)
    sigma = pm.HalfNormal("sigma", 4)
    likelihood = pm.Normal("outcome", mu = intercept + sum(response_mean), sigma = sigma,
                           observed = data_transformed[target].values)

with model_2:
    trace = pmjax.sample_numpyro_nuts(1000, tune=1000, target_accept=0.95)
    
    trace_summary = az.summary(trace)

I have added few functions to calculate the best lag that each independent variable should be subjected to.

def correlation_coefficient(X, Y):
    """
    Calculate the correlation coefficient between two theano tensors.
    """
    X_mean = tt.mean(X)
    Y_mean = tt.mean(Y)

    X_std = tt.std(X)
    Y_std = tt.std(Y)

    covariance = tt.mean((X - X_mean) * (Y - Y_mean))
    return covariance / (X_std * Y_std)

def create_lagged_vector(X, lag):
    # Function to create a lagged version of a vector
    if lag == 0:
        return X
    else:
        return tt.concatenate([tt.zeros(lag), X[:-lag]])

def find_optimal_lag(X1, y, max_lag):
    best_lag = 0
    best_correlation = -np.inf

    for lag in range(max_lag + 1):
        lagged_X1 = create_lagged_vector(X1, lag)
        correlation = correlation_coefficient(lagged_X1, y)

        # Update best lag if this is the highest correlation so far
        best_correlation = pm.math.switch(tt.gt(correlation,best_correlation),correlation,best_correlation)
        best_lag = pm.math.switch(tt.gt(correlation,best_correlation),lag,best_lag)
        # if correlation > best_correlation:
        #     best_lag = lag
        #     best_correlation = correlation

    return best_lag, best_correlation

After generating the best_lag, this lag is then applied to the transformed_X1 variable the following way

lagged_X1 = tt.concatenate([tt.zeros(best_lag),transformed_X1[:-best_lag]])

When I use lagged_X1 for further processing I get the following error, which I am unable to solve.

IndexError                                Traceback (most recent call last)
File ~/miniconda3/envs/pymc_env/lib/python3.11/site-packages/pytensor/compile/function/types.py:970, in Function.__call__(self, *args, **kwargs)
    968 try:
    969     outputs = (
--> 970         self.vm()
    971         if output_subset is None
    972         else self.vm(output_subset=output_subset)
    973     )
    974 except Exception:

IndexError: index out of bounds

During handling of the above exception, another exception occurred:

IndexError                                Traceback (most recent call last)
Cell In[3], line 182
    178     likelihood = pm.Normal("outcome", mu = intercept + sum(response_mean), sigma = sigma,
    179                            observed = data_transformed[target].values)
    181 with model_2:
--> 182     trace = pmjax.sample_numpyro_nuts(1000, tune=1000, target_accept=0.95)
    184     trace_summary = az.summary(trace)

File ~/miniconda3/envs/pymc_env/lib/python3.11/site-packages/pymc/sampling/jax.py:662, in sample_numpyro_nuts(draws, tune, chains, target_accept, random_seed, initvals, model, var_names, progressbar, keep_untransformed, chain_method, postprocessing_backend, postprocessing_vectorize, idata_kwargs, nuts_kwargs, postprocessing_chunks)
    659 tic1 = datetime.now()
..
 - Join.0, Shape: (0,), ElemSize: 8 Byte(s), TotalSize: 0 Byte(s)
 TotalSize: 34246.0 Byte(s) 0.000 GB
 TotalSize inputs: 33941.0 Byte(s) 0.000 GB

Is there a way to solve this? Thanks in advance !