Smooth Local Linear Trend with customDist

I am trying to code a smooth local linear trend model (a local linear trend model with level innovations disabled):

y[t] = level[t] + obs_err
level[t+1] = level[t] + trend[t]
trend[t+1] = trend[t] + trend_err (not trend[t+1] = trend[t+1] + trend_err, as I mistakenly wrote before)

with initial values level[0] and trend[0]. Thus, we have: y[0] = level[0]+obs_err

Using scan and a customDist, would the following implementation be correct? Notice that y is updated at the beginning in ll_step contrary to the many examples I saw in this blog where y is usually updated at the end.

lags = 1                  # number of pre-sample (initial-condition) periods
timeseries_length = 100   # total number of observation periods, including the lags

def ll_step(*args):
    """One scan step of the smooth local linear trend model.

    args = (level_tm1, trend_tm1, trend_sigma, obs_sigma): the first two come
    from scan's outputs_info (recursive states), the last two from
    non_sequences (fixed parameters).
    """
    level_tm1, trend_tm1, trend_sigma, obs_sigma = args
    # y is drawn BEFORE the state update, so the emitted y[t] is centred on the
    # incoming level (level_tm1); the first output is Normal(level0, obs_sigma).
    y = pm.Normal.dist(level_tm1, obs_sigma)
    # trend[t+1] = trend[t] + Normal(0, trend_sigma)  (trend innovation)
    trend = trend_tm1 + pm.Normal.dist(0, trend_sigma)
    # level[t+1] = level[t] + trend[t]  (no level innovation: "smooth" LLT)
    level = level_tm1 + trend_tm1
    # This is what I saw in some other similar examples
    # y = pm.Normal.dist(level, obs_sigma)
    # collect_default_updates gathers the RNG shared-variable updates of the
    # Normal draws so scan can thread the random streams across iterations.
    return (level, trend, y), collect_default_updates(inputs=args, outputs=[level, trend, y])


def ll_dist(level0, trend0, trend_sigma, obs_sigma, size):
    """dist function for pm.CustomDist: returns the scanned observation vector.

    NOTE(review): `size` is ignored and `timeseries_length`/`lags` are read
    from module-level globals — confirm this matches what CustomDist expects.
    """
    (level, trend, y), _ = pytensor.scan(
        fn=ll_step,
        outputs_info = [level0, trend0, None],  # y is non-recursive -> None
        non_sequences=[trend_sigma, obs_sigma],
        n_steps=timeseries_length - lags,
        strict=True,
    )

    # Only the observation sequence is returned; level/trend remain latent.
    return y

# Named dimensions for the model graph / posterior arrays.
coords = {
    "order": ["level", "trend"],                  # hidden-state components
    "lags": range(-lags, 0),                      # initial-condition positions
    "steps": range(timeseries_length - lags),     # scan output length
    "timeseries_length": range(timeseries_length),
}
with pm.Model(coords=coords, check_bounds=False) as model:
    # Priors
    # Innovation scale of the trend random walk.
    trend_sigma = pm.Gamma("trend_sigma", alpha=2, beta=50)

    # Hyperpriors for the means and stds
    level_mu0 = pm.Normal("level_mu0", mu=0, sigma=1)
    level_sigma0 = pm.Gamma("level_sigma0", alpha=5, beta=5)

    trend_mu0 = pm.Normal("trend_mu0", mu=0, sigma=1)
    trend_sigma0 = pm.Gamma("trend_sigma0", alpha=5, beta=5)

    # Priors with hyperpriors
    # Initial state of the system at t = 0.
    level0 = pm.Normal("level0", mu=level_mu0, sigma=level_sigma0)
    trend0 = pm.Normal("trend0", mu=trend_mu0, sigma=trend_sigma0)
    # Observation-noise scale.
    obs_sigma = pm.HalfNormal("obs_sigma", sigma=0.05)

    # The whole y time series as a single CustomDist driven by the scan above;
    # its length is timeseries_length - lags (the "steps" dimension).
    ll_steps = pm.CustomDist(
        "ll_steps",
        level0,
        trend0,
        trend_sigma,
        obs_sigma,
        dist=ll_dist,
        dims=("steps",),
    )

# Render the model's dependency graph.
pm.model_to_graphviz(model)

It seems correct at first glance, does it work?

There are some small errors in your notation, you should have trend[t+1] = trend[t] + trend_error.

The way you write y, observations will be back-shifted by one period relative to hidden states, but otherwise everything is completely identical. That is, at the 10th position of the output vector, you will have the value of the system after 9 iterations. That seems somewhat counter-intuitive to me.

@jessegrabowski Thanks for your reply. Yes, you are right:
it is trend[t+1] = trend[t] + error.
Indeed, it sounds counter-intuitive, though it ensures y[0] = level[0] + obs_err, which is correct, no?

However, I am getting an error. This is the full code:

def sampler_kwargs():
    """Return the keyword arguments used for pm.sample (nutpie/JAX backend)."""
    return {
        "nuts_sampler": "nutpie",
        "chains": 5,
        "target_accept": 0.925,
        "draws": 300,
        "tune": 700,
        "nuts_sampler_kwargs": {"backend": "jax", "gradient_backend": "jax"},
        "random_seed": 101190,
    }
lags = 0                  # no pre-sample periods in this run (all 100 points observed)
timeseries_length = 100   # total number of observation periods

def ll_step(*args):
    """One scan step of the smooth local linear trend model.

    args = (level_tm1, trend_tm1, trend_sigma, obs_sigma): the first two come
    from scan's outputs_info (recursive states), the last two from
    non_sequences (fixed parameters).
    """
    level_tm1, trend_tm1, trend_sigma, obs_sigma = args
    # y is drawn BEFORE the state update, so the emitted y[t] is centred on the
    # incoming level (level_tm1); the first output is Normal(level0, obs_sigma).
    y = pm.Normal.dist(level_tm1, obs_sigma)
    # trend[t+1] = trend[t] + Normal(0, trend_sigma)  (trend innovation)
    trend = trend_tm1 + pm.Normal.dist(0, trend_sigma)
    # level[t+1] = level[t] + trend[t]  (no level innovation: "smooth" LLT)
    level = level_tm1 + trend_tm1
    # This is what I saw in some other similar examples
    # y = pm.Normal.dist(level, obs_sigma)
    # collect_default_updates gathers the RNG shared-variable updates of the
    # Normal draws so scan can thread the random streams across iterations.
    return (level, trend, y), collect_default_updates(inputs=args, outputs=[level, trend, y])


def ll_dist(level0, trend0, trend_sigma, obs_sigma, size):
    """dist function for pm.CustomDist: returns the scanned observation vector.

    NOTE(review): `size` is ignored and `timeseries_length`/`lags` are read
    from module-level globals — confirm this matches what CustomDist expects.
    """
    (level, trend, y), _ = pytensor.scan(
        fn=ll_step,
        outputs_info = [level0, trend0, None],  # y is non-recursive -> None
        non_sequences=[trend_sigma, obs_sigma],
        n_steps=timeseries_length - lags,
        strict=True,
    )

    # Only the observation sequence is returned; level/trend remain latent.
    return y

# Named dimensions for the model graph / posterior arrays.
coords = {
    "order": ["level", "trend"],                  # hidden-state components
    #"lags": range(-lags, 0),
    "steps": range(timeseries_length - lags),     # scan output length (= 100 here)
    "timeseries_length": range(timeseries_length),
}
with pm.Model(coords=coords, check_bounds=False) as model:
    # Priors
    # Innovation scale of the trend random walk.
    trend_sigma = pm.Gamma("trend_sigma", alpha=2, beta=50)

    # Hyperpriors for the means and stds
    level_mu0 = pm.Normal("level_mu0", mu=0, sigma=1)
    level_sigma0 = pm.Gamma("level_sigma0", alpha=5, beta=5)

    trend_mu0 = pm.Normal("trend_mu0", mu=0, sigma=1)
    trend_sigma0 = pm.Gamma("trend_sigma0", alpha=5, beta=5)

    # Priors with hyperpriors
    # Initial state of the system at t = 0.
    level0 = pm.Normal("level0", mu=level_mu0, sigma=level_sigma0)
    trend0 = pm.Normal("trend0", mu=trend_mu0, sigma=trend_sigma0)
    # Observation-noise scale.
    obs_sigma = pm.HalfNormal("obs_sigma", sigma=0.05)
    # Synthetic observed series: 0.01, 0.02, ..., 1.00 (length 100, matching "steps").
    y_obs =  pm.Data('y_obs', np.array([0.01*(i+1) for i in range(100)]))
    # NOTE(review): this is the construct that triggers the reported
    # MissingInputError at sample time — likely related to how scan's RNG
    # inputs are threaded through CustomDist; confirm against the pymc docs.
    ll_steps = pm.CustomDist(
        "ll_steps",
        level0,
        trend0,
        trend_sigma,
        obs_sigma,
        dist=ll_dist,
        observed = y_obs,
        dims=("steps",),
    )

    trace = pm.sample(**sampler_kwargs())


# Render the model's dependency graph.
pm.model_to_graphviz(model)

The graph is plotted but I am getting the error:
MissingInputError: NominalGraph is missing an input: *2-

Any ideas, why?

I also added the configuration of the sampler for completeness.

Initial states are never included by scan, and need to be concatenated manually. Writing $y_{t+1} = x_{t+1} + \eta_{t+1}$ or $y_t = x_t + \eta_t$ is completely equivalent.

I’ll see if I can run your code to get a better grip on what’s wrong.