Simple ARX model implementation

I would like to implement a simple autoregressive model with exogenous inputs (ARX). The AR distribution (pm.AR) does not seem to support this. I am aware of the VARMAX model in pymc_experimental, but I don’t need such a general model, and the examples do not work.
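For reference, an ARX(p) process with a single exogenous input x_t follows y_t = c + rho_1 y_{t-1} + ... + rho_p y_{t-p} + beta x_t + eps_t, with eps_t ~ N(0, sigma^2). The NumPy simulation below is only a sketch of that recursion; the function name simulate_arx, its parameters and the restriction to one exogenous regressor are illustrative assumptions, not part of any library.

```python
import numpy as np


def simulate_arx(rho, beta, x, init, sigma, c=0.0, seed=None):
    """Simulate y_t = c + sum_i rho[i] * y[t-1-i] + beta * x[t] + eps_t.

    rho:  AR coefficients, rho[0] multiplies y[t-1], rho[1] multiplies y[t-2], ...
    x:    exogenous input, one value per time step (same length as the output)
    init: the first len(rho) values of y, taken as given
    """
    rho = np.asarray(rho, dtype=float)
    p = len(rho)
    n = len(x)
    rng = np.random.default_rng(seed)
    y = np.empty(n)
    y[:p] = init
    for t in range(p, n):
        lagged = y[t - p:t][::-1]  # y[t-1], y[t-2], ..., y[t-p]
        y[t] = c + rho @ lagged + beta * x[t] + rng.normal(0.0, sigma)
    return y
```

Setting sigma=0 makes the trace deterministic, which is a convenient way to check the recursion by hand: simulate_arx([0.5], 2.0, [0.0, 1.0, 1.0, 1.0], [1.0], sigma=0.0) gives 1.0, 2.5, 3.25, 3.625.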

I tried to create a custom distribution with pm.CustomDist, based on this notebook, but it throws IndexError: too many indices for array.

import pymc as pm
import pytensor
from pymc.pytensorf import collect_default_updates


def get_model(obs, specs):
    # obs: pandas Series/DataFrame (uses .index and .to_numpy()); specs: prior settings
    with pm.Model() as model:
        pass

    obs_index = obs.index.to_numpy()
    model.add_coord("obs_idx", obs_index)
    model.add_coord("tr_cnt", obs_index)

    with model:
        t = pm.Data("t", obs_index, dims="obs_idx")
        y = pm.Data("y", obs.to_numpy().ravel(), dims=("obs_idx",))

        init_weight = pm.Normal.dist(**specs['init_weight'])
        coefs_weight = pm.Normal("coefs_weight", **specs['coefs_weight'])
        sigma = pm.HalfNormal("sigma", specs['sigma'])

        common = {
            # with constant=True, rho[0] is the intercept, so the AR order is size - 1
            "steps": t.shape[0]-(specs['coefs_weight']['size']-1),
            "constant": True,
            "dims": "obs_idx"
        }

        # this works as expected, but has no exogenous components
        # ar_weight = pm.AR("ar_weight", rho=coefs_weight, sigma=sigma,
        #                   init_dist=init_weight, **common)

        # does not work: raises IndexError: too many indices for array
        ar_weight = pm.CustomDist("ar_weight", init_weight, coefs_weight, sigma, dist=ar_dist,
                                  dims=("obs_idx",))

        llh = pm.Normal("llh", mu=ar_weight, sigma=sigma, observed=y, dims=("obs_idx",))

    return model
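
def ar_dist(rho, sigma, init, size):
    # pm.CustomDist calls this positionally as dist(*dist_params, size),
    # i.e. here as ar_dist(init_weight, coefs_weight, sigma, size)
    lags = 1
    trials = 25  # hard-coded length; the scan returns trials - lags values

    def ar_step(val, rho, sigma):
        # one AR(1) step: previous value times rho plus Gaussian noise
        mu = val * rho
        x = mu + pm.Normal.dist(sigma=sigma)
        # register the RNG update so every step draws fresh noise
        return x, collect_default_updates([x])

    ar_innov, _ = pytensor.scan(
        fn=ar_step,
        # the initial value needs a leading dimension of length lags
        outputs_info=[{"initial": init, "taps": range(-lags, 0)}],
        non_sequences=[rho, sigma],
        n_steps=trials-lags,
        strict=True,
    )

    return ar_innov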
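Without the full traceback this is an inference, but one likely cause is the argument order. pm.CustomDist forwards its positional dist params to the dist function in the order given, followed by size, so with the call above ar_dist receives init_weight as rho, coefs_weight as sigma and sigma as init. scan then uses the scalar sigma as the initial state and tries to take a lag tap from a 0-d tensor, which is the kind of operation that raises "too many indices for array". A plain-Python stand-in (ar_dist_stub is hypothetical and only records what each parameter receives) shows the binding:

```python
def ar_dist_stub(rho, sigma, init, size):
    # Same signature as ar_dist in the question; it just reports which
    # value ends up in each parameter.
    return {"rho": rho, "sigma": sigma, "init": init, "size": size}


# pm.CustomDist("ar_weight", init_weight, coefs_weight, sigma, dist=ar_dist, ...)
# passes the dist params positionally, then size:
bound = ar_dist_stub("init_weight", "coefs_weight", "sigma", "size")
print(bound)
# {'rho': 'init_weight', 'sigma': 'coefs_weight', 'init': 'sigma', 'size': 'size'}
```

Even with the order fixed, init_weight would need a leading dimension of length lags, it should probably be a named model variable (as in the notebook) rather than a .dist, and the scan length (trials - lags, hard-coded to 24 here) has to match the length of the obs_idx dimension.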

Could you please point me in the right direction here?

Why don’t the examples work?

I would need to see the full traceback to diagnose this specific model. That said, the updated ARMA example notebook is the best place to start.