Smooth Local Linear Trend with CustomDist

I am trying to code a smooth local linear trend model (a local linear trend model with level innovations disabled):

y[t] = level[t] + obs_err
level[t+1] = level[t] + trend[t]
trend[t+1] = trend[t] + trend_err (edited: I originally wrote trend[t+1] = trend[t+1] + trend_err)

with initial values level[0] and trend[0]. Thus we have y[0] = level[0] + obs_err.

Using scan and a CustomDist, would the following implementation be correct? Notice that y is updated at the beginning of ll_step, contrary to the many examples I have seen here, where y is usually updated at the end.

import pymc as pm
import pytensor
from pymc.pytensorf import collect_default_updates

lags = 1
timeseries_length = 100

def ll_step(*args):
    level_tm1, trend_tm1, trend_sigma, obs_sigma = args
    # Observation is drawn from the previous level, i.e. before the state update
    y = pm.Normal.dist(level_tm1, obs_sigma)
    trend = trend_tm1 + pm.Normal.dist(0, trend_sigma)
    level = level_tm1 + trend_tm1
    # This is what I saw in some other similar examples
    # y = pm.Normal.dist(level, obs_sigma)
    return (level, trend, y), collect_default_updates(inputs=args, outputs=[level, trend, y])


def ll_dist(level0, trend0, trend_sigma, obs_sigma, size):
    (level, trend, y), _ = pytensor.scan(
        fn=ll_step,
        outputs_info=[level0, trend0, None],
        non_sequences=[trend_sigma, obs_sigma],
        n_steps=timeseries_length - lags,
        strict=True,
    )

    return y

coords = {
    "order": ["level", "trend"],
    "lags": range(-lags, 0),
    "steps": range(timeseries_length - lags),
    "timeseries_length": range(timeseries_length),
}
with pm.Model(coords=coords, check_bounds=False) as model:
    # Priors
    trend_sigma = pm.Gamma("trend_sigma", alpha=2, beta=50)
    
    # Hyperpriors for the means and stds
    level_mu0 = pm.Normal("level_mu0", mu=0, sigma=1)
    level_sigma0 = pm.Gamma("level_sigma0", alpha=5, beta=5)
    
    trend_mu0 = pm.Normal("trend_mu0", mu=0, sigma=1)
    trend_sigma0 = pm.Gamma("trend_sigma0", alpha=5, beta=5)

    # Priors with hyperpriors
    level0 = pm.Normal("level0", mu=level_mu0, sigma=level_sigma0)
    trend0 = pm.Normal("trend0", mu=trend_mu0, sigma=trend_sigma0)
    obs_sigma = pm.HalfNormal("obs_sigma", sigma=0.05)
    
    ll_steps = pm.CustomDist(
        "ll_steps",
        level0,
        trend0,
        trend_sigma,
        obs_sigma,
        dist=ll_dist,
        dims=("steps",),
    )

pm.model_to_graphviz(model)
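
As a quick sanity check of the generative part, I think one can draw from the prior directly (a sketch; pm.draw samples the graph forward):

# Draw a few prior realizations of the time series to eyeball the trajectories
prior_draws = pm.draw(ll_steps, draws=3, random_seed=1)
print(prior_draws.shape)  # (3, timeseries_length - lags)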

It seems correct at first glance, but does it work?

There are some small errors in your notation: you should have trend[t+1] = trend[t] + trend_err.

The way you write y, the observations will be back-shifted by one period relative to the hidden states, but otherwise everything is identical. That is, at the 10th position of the output vector you will have the value of the system after 9 iterations. That seems somewhat counter-intuitive to me.
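
Here is a quick pure-NumPy sketch of the shift I mean (an illustration only, not the model code):

import numpy as np

rng = np.random.default_rng(0)
level, trend = 0.0, 0.1
y_first, y_last = [], []
for _ in range(5):
    y_first.append(level)                # observe before the state update
    level = level + trend                # level[t+1] = level[t] + trend[t]
    trend = trend + rng.normal(0, 0.01)  # trend[t+1] = trend[t] + trend_err
    y_last.append(level)                 # observe after the state update

# y_first equals y_last shifted back by one step, with the initial level prepended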

@jessegrabowski Thanks for your reply. Yes, you are right:
it is trend[t+1] = trend[t] + trend_err.
Indeed, it sounds counter-intuitive, but it ensures y[0] = level[0] + obs_err, which is correct, no?

However, I am getting an error. This is the full code:

import numpy as np
import pymc as pm
import pytensor
from pymc.pytensorf import collect_default_updates

def sampler_kwargs():
    # nutpie sampler with a JAX backend for both sampling and gradients
    return dict(
        nuts_sampler="nutpie",
        chains=5,
        target_accept=0.925,
        draws=300,
        tune=700,
        nuts_sampler_kwargs={"backend": "jax", "gradient_backend": "jax"},
        random_seed=101190,
    )
lags = 0
timeseries_length = 100

def ll_step(*args):
    level_tm1, trend_tm1, trend_sigma, obs_sigma = args
    y = pm.Normal.dist(level_tm1, obs_sigma)
    trend = trend_tm1 + pm.Normal.dist(0, trend_sigma)
    level = level_tm1 + trend_tm1
    # This is what I saw in some other similar examples
    # y = pm.Normal.dist(level, obs_sigma)
    return (level, trend, y), collect_default_updates(inputs=args, outputs=[level, trend, y])


def ll_dist(level0, trend0, trend_sigma, obs_sigma, size):
    (level, trend, y), _ = pytensor.scan(
        fn=ll_step,
        outputs_info=[level0, trend0, None],
        non_sequences=[trend_sigma, obs_sigma],
        n_steps=timeseries_length - lags,
        strict=True,
    )

    return y

coords = {
    "order": ["level", "trend"],
    #"lags": range(-lags, 0),
    "steps": range(timeseries_length - lags),
    "timeseries_length": range(timeseries_length),
}
with pm.Model(coords=coords, check_bounds=False) as model:
    # Priors
    trend_sigma = pm.Gamma("trend_sigma", alpha=2, beta=50)
    
    # Hyperpriors for the means and stds
    level_mu0 = pm.Normal("level_mu0", mu=0, sigma=1)
    level_sigma0 = pm.Gamma("level_sigma0", alpha=5, beta=5)
    
    trend_mu0 = pm.Normal("trend_mu0", mu=0, sigma=1)
    trend_sigma0 = pm.Gamma("trend_sigma0", alpha=5, beta=5)

    # Priors with hyperpriors
    level0 = pm.Normal("level0", mu=level_mu0, sigma=level_sigma0)
    trend0 = pm.Normal("trend0", mu=trend_mu0, sigma=trend_sigma0)
    obs_sigma = pm.HalfNormal("obs_sigma", sigma=0.05)
    y_obs = pm.Data("y_obs", np.array([0.01 * (i + 1) for i in range(100)]))
    ll_steps = pm.CustomDist(
        "ll_steps",
        level0,
        trend0,
        trend_sigma,
        obs_sigma,
        dist=ll_dist,
        observed=y_obs,
        dims=("steps",),
    )

    trace = pm.sample(**sampler_kwargs())


pm.model_to_graphviz(model)

The graph is plotted, but I am getting this error:
MissingInputError: NominalGraph is missing an input: *2-

Any ideas why?

I also added the configuration of the sampler for completeness.

Scan never includes the initial states in its outputs; they need to be concatenated manually. Writing y_{t+1} = x_{t+1} + \eta_{t+1} or y_t = x_t + \eta_t is completely equivalent.
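
For example (a sketch using the names from your ll_dist, assuming level0 is a scalar):

import pytensor.tensor as pt

# Prepend the initial level so that index t of the result is level[t];
# the scan output itself starts at level[1]
level_full = pt.concatenate([level0[None], level], axis=0)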

I’ll see if I can run your code to get a better grip on what’s wrong.

@jessegrabowski Thanks for the comment. Does this mean that the first time this function is called

def ll_step(*args):
    level_tm1, trend_tm1, trend_sigma, obs_sigma = args
    y = pm.Normal.dist(level_tm1, obs_sigma)
    level = level_tm1 + trend_tm1
    trend = pm.Normal.dist(trend_tm1, trend_sigma)
    return (level, trend, y), collect_default_updates(inputs=args, outputs=[level, trend, y])

level_tm1 and trend_tm1 are not the initial states? I was assuming that they are level[0] and trend[0] respectively and, therefore, that

y = pm.Normal.dist(level_tm1, obs_sigma)

would encode y[0] ~ N(level[0], obs_sigma), following the model above.

I cleaned up my code a little, but it is now failing at line 116 of pymc/logprob/scan.py:

assert oo_var in new_outer_input_vars

I was not able to fix it. The graph is produced. Please uncomment the trace line to reproduce the error.

import numpy as np
import pymc as pm
import pytensor
from pymc.pytensorf import collect_default_updates

def sampler_kwargs():
    # nutpie sampler with a JAX backend for both sampling and gradients
    return dict(
        nuts_sampler="nutpie",
        chains=5,
        cores=6,
        target_accept=0.925,
        draws=3000,
        tune=3000,
        nuts_sampler_kwargs={"backend": "jax", "gradient_backend": "jax"},
        random_seed=42,
    )

# Simulate a Local Linear Trend model
# Set seed for reproducibility
np.random.seed(42)

# Parameters
T = 100 # Number of time steps
sigma_trend = 0.005  # Trend disturbance std dev
sigma_obs = 0.002    # Observation noise std dev

# Simulate initial level and trend
mu_0 = np.random.normal(loc=-0.03, scale=0.02)    # initial level
beta_0 = np.random.normal(loc=0.02, scale=0.01)  # initial trend

# Initialize arrays
mu = np.zeros(T)
beta = np.zeros(T)
y = np.zeros(T)

# Set initial values
mu[0] = mu_0
beta[0] = beta_0
y[0] = mu[0] + np.random.normal(0, sigma_obs)

# Simulate time series
# The number of innovations is T-1.
# Note that beta[T-1] is never used: the last observation is
# y[T-1] = mu[T-1] + noise, with mu[T-1] = mu[T-2] + beta[T-2].
for t in range(1, T):
    inno1 = np.random.normal(0, sigma_trend)
    inno2 = np.random.normal(0, sigma_obs)
    beta[t] = beta[t-1] + inno1
    mu[t] = mu[t-1] + beta[t-1]  # No innovation in level
    y[t] = mu[t] + inno2



# Smooth Local Linear Trend Model
lags = 1
timeseries_length = 100

def ll_step(*args):
    level_tm1, trend_tm1, trend_sigma, obs_sigma = args
    y = pm.Normal.dist(level_tm1, obs_sigma)
    level = level_tm1 + trend_tm1
    trend = pm.Normal.dist(trend_tm1, trend_sigma)
    return (level, trend, y), collect_default_updates(inputs=args, outputs=[level, trend, y])


def ll_dist(level0, trend0, trend_sigma, obs_sigma, size):
    (level, trend, y), _ = pytensor.scan(
        fn=ll_step,
        outputs_info=[level0, trend0, None],
        non_sequences=[trend_sigma, obs_sigma],
        n_steps=timeseries_length - lags,
        strict=True,
    )

    return y

coords = {
    "lags": range(-lags, 0),
    "steps": range(timeseries_length - lags),
    "timeseries_length": range(timeseries_length),
}
with pm.Model(coords=coords, check_bounds=False) as model:
    # Priors
    trend_sigma = pm.Gamma("trend_sigma", alpha=2, beta=50)
    
    # Hyperpriors for the means and stds
    level_mu0 = pm.Normal("level_mu0", mu=0, sigma=1)
    level_sigma0 = pm.Gamma("level_sigma0", alpha=5, beta=5)
    
    trend_mu0 = pm.Normal("trend_mu0", mu=0, sigma=1)
    trend_sigma0 = pm.Gamma("trend_sigma0", alpha=5, beta=5)

    # Priors with hyperpriors
    level0 = pm.Normal("level0", mu=level_mu0, sigma=level_sigma0)
    trend0 = pm.Normal("trend0", mu=trend_mu0, sigma=trend_sigma0)
    obs_sigma = pm.HalfNormal("obs_sigma", sigma=0.05)
    y_obs = pm.Data("y_obs", y[:99])
    ll_steps = pm.CustomDist(
        "ll_steps",
        level0,
        trend0,
        trend_sigma,
        obs_sigma,
        dist=ll_dist,
        observed=y_obs,
        dims=("steps",),
    )

    #trace = pm.sample(**sampler_kwargs())


pm.model_to_graphviz(model)

I would have thought you could just write the generative model down directly as follows:

trend[0] = initial trend value (can be unknown and given a distribution, or given as data)
level[0] = initial level value (ditto)
trend_err = pm.Normal.dist(0, trend_sigma)
for t in range(1, T):
    trend[t] = trend[t - 1] + trend_err[t]
    level[t] = level[t - 1] + trend[t]
y = pm.Normal.dist(level, obs_sigma, observe=y_obs)

Warning: I almost certainly messed up the observe syntax, and I just assumed that the normal distribution can be applied to containers of conforming sizes. If not, just move it inside the loop and use:

trend[0] = initial trend value (can be unknown and given a distribution, or given as data)
level[0] = initial level value (ditto)
for t in range(1, T):
    trend_err[t] = pm.Normal.dist(0, trend_sigma)
    trend[t] = trend[t - 1] + trend_err[t]
    level[t] = level[t - 1] + trend[t]
    y[t] = pm.Normal.dist(level[t], obs_sigma, observe=y_obs[t])

This is one of the places where the PyMC/PyTensor API is less ergonomic. You don't want to define a graph using Python loops (or, more generally, Python control flow like if/else statements).

In the loop case you get a massive, non-vectorized computational graph (basically a statically unrolled loop). You usually have to use a Scan to define a symbolic loop, which is a single node in the computational graph. It's also a kind of tape operator that then enables backprop through the intermediate steps. JAX introduces the same construct for the same reason.
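
A minimal standalone sketch of the idea (my example, not from this thread): a cumulative sum written as a scan is one Scan node no matter how many steps it runs, whereas the equivalent Python loop would unroll into one add node per step.

import pytensor
import pytensor.tensor as pt

x = pt.vector("x")
# A single Scan node, regardless of the length of x
cumsum, _ = pytensor.scan(
    fn=lambda x_t, acc: acc + x_t,
    sequences=[x],
    outputs_info=[pt.zeros(())],
)
f = pytensor.function([x], cumsum)
print(f([1.0, 2.0, 3.0]))  # [1. 3. 6.]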

The CustomDist thing is separate. You can use one to wrap a scan-based random process, and it will figure out the logp for you. But you don't really need it; it's just a way to make it look and act like a pre-built RV/distribution.
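
For instance, a minimal sketch of that pattern (assuming a recent PyMC version; a Gaussian random walk stands in for the trend model):

import pymc as pm
import pytensor
from pymc.pytensorf import collect_default_updates

def rw_dist(x0, sigma, size):
    def step(x_tm1, sigma):
        x = pm.Normal.dist(x_tm1, sigma)
        return x, collect_default_updates(outputs=[x])

    x, _ = pytensor.scan(fn=step, outputs_info=[x0], non_sequences=[sigma], n_steps=10)
    return x

with pm.Model() as m:
    sigma = pm.HalfNormal("sigma")
    x0 = pm.Normal("x0")
    # CustomDist derives the logp of the whole walk from the generative scan
    rw = pm.CustomDist("rw", x0, sigma, dist=rw_dist)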

pm.Normal.dist(level[t], obs_sigma, observe=y_obs[t])

This is not valid syntax; you can't pass observed values to a .dist. Also, the RV returned by .dist is not registered in the model, so the way you wrote it, it's a no-op. Technicalities…

There's a Model.register_rv method that accepts a random variable like the ones returned by .dist together with an observed value, but users don't usually work with that directly. They use some pm.Foo that accepts an observed argument, creates an RV, and calls the model method. That's one of the things CustomDist does that makes it "feel and act" like the generic distributions.
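
In other words, the usual route (a minimal sketch):

import numpy as np
import pymc as pm

y_data = np.array([0.1, 0.2, 0.3])

with pm.Model() as m:
    mu = pm.Normal("mu")
    # pm.Normal("y", ..., observed=...) creates the RV *and* registers it in the
    # model with the data attached; pm.Normal.dist(...) alone builds a
    # free-floating RV that the model never sees
    y = pm.Normal("y", mu=mu, sigma=0.1, observed=y_data)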