Smooth Local Linear Trend with CustomDist

@jessegrabowski Thanks for the comment. Does this mean that, the first time this function is called,

def ll_step(*args):
    level_tm1, trend_tm1, trend_sigma, obs_sigma = args
    y = pm.Normal.dist(level_tm1, obs_sigma)
    level = level_tm1 + trend_tm1
    trend = pm.Normal.dist(trend_tm1, trend_sigma)
    return (level, trend, y), collect_default_updates(inputs=args, outputs=[level, trend, y])

level_tm1 and trend_tm1 are not the initial states? I was assuming they were level[0] and trend[0] respectively, so that

y = pm.Normal.dist(level_tm1, obs_sigma)

would encode y[0] ~ N(level[0], obs_sigma).
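
To make my reading concrete, here is a plain-NumPy unrolling of what I think the scan computes (variable names and initial values are mine, purely for illustration):

import numpy as np

rng = np.random.default_rng(0)
level, trend = -0.03, 0.02             # stand-ins for level0, trend0
trend_sigma, obs_sigma = 0.005, 0.002
ys = []
for _ in range(5):
    ys.append(rng.normal(level, obs_sigma))  # y[t] ~ N(level[t], obs_sigma)
    level = level + trend                    # level[t+1] = level[t] + trend[t]
    trend = rng.normal(trend, trend_sigma)   # trend[t+1] ~ N(trend[t], trend_sigma)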

I cleaned up my code a bit, but it now fails at line 116 of pymc/logprob/scan.py,

assert oo_var in new_outer_input_vars

and I was not able to fix it. The model graph itself is still produced. Uncomment the trace line below to reproduce the error.

import numpy as np
import pymc as pm
import pytensor
from pymc.pytensorf import collect_default_updates


def sampler_kwargs():
    return dict(
        nuts_sampler="nutpie",
        chains=5,
        cores=6,
        target_accept=0.925,
        draws=3000,
        tune=3000,
        nuts_sampler_kwargs={"backend": "jax", "gradient_backend": "jax"},
        random_seed=42,
    )

# Simulate a Local Linear Trend model
# Set seed for reproducibility
np.random.seed(42)

# Parameters
T = 100 # Number of time steps
sigma_trend = 0.005  # Trend disturbance std dev
sigma_obs = 0.002    # Observation noise std dev

# Simulate initial level and trend
mu_0 = np.random.normal(loc=-0.03, scale=0.02)    # initial level
beta_0 = np.random.normal(loc=0.02, scale=0.01)  # initial trend

# Initialize arrays
mu = np.zeros(T)
beta = np.zeros(T)
y = np.zeros(T)

# Set initial values
mu[0] = mu_0
beta[0] = beta_0
y[0] = mu[0] + np.random.normal(0, sigma_obs)

# Simulate time series
# The number of innovations is T-1.
# Note that beta[T-1] is never used: the last observation is
# y[T-1] = mu[T-1] + noise, where mu[T-1] = mu[T-2] + beta[T-2].
for t in range(1, T):
    inno1 = np.random.normal(0, sigma_trend)
    inno2 = np.random.normal(0, sigma_obs)
    beta[t] = beta[t-1] + inno1
    mu[t] = mu[t-1] + beta[t-1]  # No innovation in level
    y[t] = mu[t] + inno2
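
# (My addition, not part of the original simulation: a quick visual sanity
# check, assuming matplotlib is installed. With sigma_obs this small, the
# observed y should hug the latent level mu.)
import matplotlib.pyplot as plt

plt.plot(y, label="y (observed)")
plt.plot(mu, label="mu (latent level)")
plt.legend()
plt.show()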



# Smooth Local Linear Trend Model
lags = 1
timeseries_length = 100

def ll_step(*args):
    # One scan step: emit y[t] from the previous level, then advance the
    # level deterministically and the trend with Gaussian noise.
    level_tm1, trend_tm1, trend_sigma, obs_sigma = args
    y = pm.Normal.dist(level_tm1, obs_sigma)
    level = level_tm1 + trend_tm1
    trend = pm.Normal.dist(trend_tm1, trend_sigma)
    # collect_default_updates gathers the RNG updates of the inner RVs so
    # that scan can thread them through the loop.
    return (level, trend, y), collect_default_updates(inputs=args, outputs=[level, trend, y])


def ll_dist(level, trend, trend_sigma, obs_sigma, size):
    # level and trend enter as initial states; y has no recurrence, so its
    # outputs_info entry is None.
    (level, trend, y), _ = pytensor.scan(
        fn=ll_step,
        outputs_info=[level, trend, None],
        non_sequences=[trend_sigma, obs_sigma],
        n_steps=timeseries_length - lags,
        strict=True,
    )

    return y
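
# (My addition, to show that the forward graph samples fine even though the
# logp derivation fails: pm.CustomDist.dist builds the distribution outside
# a model, and pm.draw takes a draw from it. Parameter values here are
# illustrative only.)
ll_prior = pm.CustomDist.dist(0.0, 0.02, 0.005, 0.002, dist=ll_dist)
print(pm.draw(ll_prior, random_seed=42)[:5])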

coords = {
    "lags": range(-lags, 0),
    "steps": range(timeseries_length - lags),
    "timeseries_length": range(timeseries_length),
}
with pm.Model(coords=coords, check_bounds=False) as model:
    # Priors
    trend_sigma = pm.Gamma("trend_sigma", alpha=2, beta=50)
    
    # Hyperpriors for the means and stds
    level_mu0 = pm.Normal("level_mu0", mu=0, sigma=1)
    level_sigma0 = pm.Gamma("level_sigma0", alpha=5, beta=5)
    
    trend_mu0 = pm.Normal("trend_mu0", mu=0, sigma=1)
    trend_sigma0 = pm.Gamma("trend_sigma0", alpha=5, beta=5)

    # Priors with hyperpriors
    level0 = pm.Normal("level0", mu=level_mu0, sigma=level_sigma0)
    trend0 = pm.Normal("trend0", mu=trend_mu0, sigma=trend_sigma0)
    obs_sigma = pm.HalfNormal("obs_sigma", sigma=0.05)
    y_obs = pm.Data("y_obs", y[:timeseries_length - lags])
    ll_steps = pm.CustomDist(
        "ll_steps",
        level0,
        trend0,
        trend_sigma,
        obs_sigma,
        dist=ll_dist,
        observed=y_obs,
        dims=("steps",),
    )

    # Uncomment to reproduce the assertion error in pymc/logprob/scan.py:
    # trace = pm.sample(**sampler_kwargs())


pm.model_to_graphviz(model)
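
# (My addition: I expect compiling the model logp to trigger the same
# "assert oo_var in new_outer_input_vars" failure without launching the
# sampler, since the error occurs while PyMC derives the scan's logprob.)
model.compile_logp()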