I’m a bit confused by what you mean by multivariate in this context. Do you just want to fit several independent time series at the same time, e.g. by adding a batch dimension? The only wrinkle there is that the time dimension always has to be the left-most dimension in the scan, which is confusing since that’s usually where we think of the batch dimension going. Here’s an example:
import pytensor
import pytensor.tensor as pt
import pymc as pm
from pymc.pytensorf import collect_default_updates
import arviz as az
# Order of the autoregressive process: how many past values feed each step.
lags = 2
# Length of every series, counting the initial values.
trials = 100
# Number of independent series stacked along the batch axis.
n_timeseries = 5
def ar_dist(ar_init, rho, sigma, size):
    """Build the symbolic AR innovations that follow the initial values.

    Parameters
    ----------
    ar_init
        Starting values of the process. ``scan`` reads them through the taps
        ``range(-lags, 0)``, so the leading axis must hold ``lags`` entries,
        oldest first.
    rho
        Autoregressive coefficients, indexed along the leading axis.
        ``rho[i]`` multiplies the value at tap ``-lags + i``, so the position
        of a coefficient matches the position of its tap in
        ``range(-lags, 0)``.
    sigma
        Scale of the Normal innovation added at every step.
    size
        Not used in the body; kept in the signature because this function is
        handed to ``pm.CustomDist`` through ``dist=``.

    Returns
    -------
    The ``scan`` output: ``trials - lags`` steps along the leading axis.
    """

    def ar_step(*args):
        # scan hands over the tapped values in tap order (oldest first),
        # followed by the non-sequences, i.e. (x_{t-lags}, ..., x_{t-1}, rho, sigma).
        *x_lagged, rho, sigma = args
        # Pair rho[i] with tap -lags + i so coefficient labels and lags agree.
        mu = sum(rho[i] * x_lag for i, x_lag in enumerate(x_lagged))
        x = mu + pm.Normal.dist(sigma=sigma)
        # The random draw carries RNG updates that scan must be told about.
        return x, collect_default_updates([x])

    ar_innov, _ = pytensor.scan(
        fn=ar_step,
        outputs_info=[{"initial": ar_init, "taps": range(-lags, 0)}],
        non_sequences=[rho, sigma],
        n_steps=trials - lags,
        strict=True,
    )
    return ar_innov
# Named coordinates for every model dimension.
coords = dict(
    lags=range(-lags, 0),  # tap labels of the initial values, -lags .. -1
    steps=range(trials - lags),  # steps produced after the initial values
    trials=range(trials),  # the full series: initial values plus steps
    batch=range(n_timeseries),  # one entry per independent series
)
with pm.Model(coords=coords, check_bounds=False) as batch_model:
    # Every parameter carries a trailing "batch" axis, giving each series
    # its own coefficients, innovation scale and starting values.
    rho = pm.Normal("rho", mu=0, sigma=0.2, dims=("lags", "batch"))
    sigma = pm.HalfNormal("sigma", sigma=0.2, dims="batch")
    ar_init = pm.Normal("ar_init", dims=("lags", "batch"))

    # Steps generated by the scan in ar_dist; time is the leading axis.
    ar_innov = pm.CustomDist(
        "ar_dist", ar_init, rho, sigma, dist=ar_dist, dims=("steps", "batch")
    )

    # Full series: the initial values followed by the generated steps.
    ar = pm.Deterministic(
        "ar",
        pt.concatenate([ar_init, ar_innov], axis=0),
        dims=("trials", "batch"),
    )
With a parameter recovery exercise:
import matplotlib.pyplot as plt
# Parameter recovery: simulate from the prior, then refit to one simulated draw.
with batch_model:
    prior_idata = pm.sample_prior_predictive()

# Take the first prior draw of the scan output as synthetic observations.
test_data = prior_idata.prior.ar_dist.sel(chain=0, draw=0).values
with pm.observe(batch_model, {"ar_dist": test_data}):
    idata = pm.sample()

# Coefficients that generated the synthetic data, from the same prior draw.
true_rhos = prior_idata.prior.rho.sel(chain=0, draw=0).values
# Lag -1 is the last entry along the "lags" axis, so the reference values come
# from the last row; one panel per series keeps the grid tied to n_timeseries.
# The trailing semicolon suppresses the returned axes when run in a notebook.
az.plot_posterior(
    idata,
    var_names=["rho"],
    coords={"lags": -1},
    ref_val=true_rhos[-1].tolist(),
    grid=(1, n_timeseries),
);
One nitpick on nomenclature: passing the volatility process to a distribution via the sigma parameter doesn’t create shocks; it creates a kind of time-varying measurement-error model. Your shocks are already defined in the scan model by the pm.Normal.dist term, which you could change to a Student’s t distribution if you wanted fat-tailed innovations. I’m not sure these terms are needed at all.
