Range vs list in the taps specification of scan for an AR(1) model

I was following the tutorial "Time Series Models Derived From a Generative Graph" from the PyMC example gallery and hit an odd error when changing the number of lags from 2 to 1. If I specify the taps parameter as range(-1, 0) it works, but if I specify [-1] instead it fails with a shape mismatch. I've gone back and forth with scan several times because I'm still learning how it works, and I wonder whether this behaviour reflects an important feature I'm missing (I work a lot with taps=[-1] or taps=[0] for exogenous variables).

import numpy as np
import pymc as pm
import pytensor
import pytensor.tensor as pt

from pymc.pytensorf import collect_default_updates

rng = np.random.default_rng(42)

lags = 1  # Number of lags
timeseries_length = 100  # Time series length

def ar_step(x_tm1, rho, sigma):
    mu = x_tm1 * rho[0]
    x = mu + pm.Normal.dist(sigma=sigma)
    return x, collect_default_updates([x])


def ar_dist(ar_init, rho, sigma, size):
    ar_innov, _ = pytensor.scan(
        fn=ar_step,
        outputs_info=[{"initial": ar_init, "taps": [-1]}],  # fails
        # outputs_info=[{"initial": ar_init, "taps": range(-1, 0)}],  # works
        non_sequences=[rho, sigma],
        n_steps=timeseries_length - lags,
        strict=True,
    )
    return ar_innov

coords = {
    "lags": range(-lags, 0),
    "steps": range(timeseries_length - lags),
    "timeseries_length": range(timeseries_length),
}
with pm.Model(coords=coords, check_bounds=False) as model:
    rho = pm.Normal(name="rho", mu=0, sigma=0.2, dims=("lags",))
    sigma = pm.HalfNormal(name="sigma", sigma=0.2)
    # Initial values of the series, one per lag
    ar_init = pm.Normal(name="ar_init", sigma=0.5, dims=("lags",))
    # Remaining steps, generated by the scan-based AR process
    ar_innov = pm.CustomDist(
        "ar_dist",
        ar_init,
        rho,
        sigma,
        dist=ar_dist,
        dims=("steps",),
    )
    # Full series: initial values followed by the scanned steps
    ar = pm.Deterministic(
        name="ar", var=pt.concatenate([ar_init, ar_innov], axis=-1), dims=("timeseries_length",)
    )

The code above crashes with the error:
ValueError: ar_dist has 2 dims but 1 dim labels were provided.
Changing the outputs_info line to

        outputs_info=[{"initial": ar_init, "taps": range(-1, 0)}],

makes it work.

I imagine it has to do with treating a single-valued variable as a scalar versus a 1-dimensional vector, since I've also seen that writing rho instead of rho[0] in ar_step causes similar trouble, but I don't yet see why [-1] adds an additional (unwanted) dimension while range(-1, 0) does not.
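
To make that guess concrete, here is a small plain-Python sketch with no scan involved. The first part only checks that the two taps values hold the same single lag yet do not compare equal as objects. The second part is a hypothetical illustration of the two ways a length-1 initial value could be read. Whether scan actually picks between these readings based on the taps value is an assumption on my part, which I have not confirmed in the source, and the names as_vector_state and as_scalar_state are made up for this example.

```python
# Plain-Python check, no scan involved.
taps_list = [-1]
taps_range = range(-1, 0)

# Both describe the same single lag ...
same_values = list(taps_range) == taps_list
# ... but a range object never compares equal to a list.
equal_objects = taps_range == taps_list

# Hypothetical illustration (an assumption, not confirmed scan behaviour)
# of two readings of a length-1 initial value.
ar_init = [0.3]  # stands in for the shape-(1,) ar_init
n_steps = 3

# Reading A: the whole vector is the state, so each step emits a length-1
# vector and the stacked result has two dimensions (n_steps x 1).
as_vector_state = [list(ar_init) for _ in range(n_steps)]

# Reading B: the vector is a length-1 history of scalar states, so each step
# emits a scalar and the stacked result has one dimension (n_steps).
as_scalar_state = [ar_init[-1] for _ in range(n_steps)]

print(same_values, equal_objects)
print(len(as_vector_state), len(as_vector_state[0]))
print(len(as_scalar_state))
```

If Reading A is what happens with [-1], the extra trailing dimension would match the "2 dims but 1 dim labels" message above.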

Any insight is appreciated (including general tips for debugging scan).
