Moving average component for time-series modelling

Hello pymc community!
I’d like to write a simple Bayesian model for time-series forecasting, and I’d like it to have an MA(q) component (where q is not fixed at implementation time). I’ve started by writing a model with this component alone. Comparison with statsmodels on dummy data gives me some confidence that my code is doing what I think it should. However, the pm.Model I’ve got is rather slow to sample. I would like to ask for your opinions and suggestions on whether there are any modifications or reparametrizations that I could apply. Here’s my script:

import numpy as np
import matplotlib.pyplot as plt
import scipy
import statsmodels.api as sm
import pymc as pm
import arviz as az
import aesara
import aesara.tensor as at

# generate dummy data
# Seeded generator shared by all the random draws that build the dummy series,
# so the simulated data is the same on every run.
rng = np.random.default_rng(0)

def ma_process(n, theta, mu, rng):
    """Simulate ``n`` observations of an MA(q) process around the level ``mu``.

    Each observation is ``mu + eps_t + theta[0] * eps_{t-1} + ... +
    theta[q-1] * eps_{t-q}``, where the ``eps`` are standard-normal draws
    taken from ``rng`` and ``q = len(theta)``.

    Parameters
    ----------
    n : int
        Number of observations to return.
    theta : sequence of float
        MA coefficients, ordered from lag 1 to lag q.
    mu : float
        Constant added to every observation.
    rng : numpy.random.Generator
        Source of the white-noise draws.

    Returns
    -------
    numpy.ndarray
        Array of length ``n``.
    """
    q = len(theta)
    # Leading 1 is the weight of the current innovation eps_t.
    kernel = np.concatenate([[1], theta])
    # q extra draws so that the 'valid' convolution yields exactly n samples.
    innovations = rng.normal(size=n + q)
    return mu + scipy.signal.convolve(innovations, kernel, mode='valid')

n = 300  # number of observations in the simulated series
mu_gt = rng.normal(0, 4)  # ground-truth level of the series
q_gt = 4  # ground-truth MA order
# ground-truth MA coefficients; np.abs makes all of them non-negative
theta_gt = np.abs(rng.normal(0, 0.5, size=q_gt))
y_ma = ma_process(n, theta_gt, mu_gt, rng)

# quick look at the simulated series
plt.figure(figsize=(12, 3))
plt.plot(y_ma);

# fit MA from statsmodels to have some sanity-check
# The MA order is one more than the number of simulated coefficients.
sm_results_ma = sm.tsa.arima.ARIMA(y_ma, order=(0, 0, len(theta_gt) + 1)).fit()

# pick out inferred statsmodels params
# Parameter layout relied on here: the constant first, then the MA coefficients.
mu_sm = sm_results_ma.params[0]
theta_sm = sm_results_ma.maparams

# 95% confidence intervals straight from the results object rather than
# parsed out of the printed summary table: same rows, but full precision and
# no dependence on the table's text layout.
sm_conf_int = np.asarray(sm_results_ma.conf_int(alpha=0.05))
mu_sm_interval = sm_conf_int[0]  # shape (2,): lower, upper
theta_sm_intervals = sm_conf_int[1:1 + len(theta_sm)]  # shape (q, 2)

# my pymc model

def residual_fn(observed, past_q_residuals, mu, theta_rev):
    """Single step of the MA recursion, written to be used as a scan body.

    Parameters
    ----------
    observed
        The observation for the current step.
    past_q_residuals
        Window of previous residuals, ordered oldest first.
    mu
        Constant level of the series.
    theta_rev
        MA coefficients in reversed order, so that they line up with the
        oldest-first residual window.

    Returns
    -------
    tuple
        ``(residual, window)``: the residual of the current step, and the
        window with its oldest entry dropped and the new residual appended.
    """
    ma_term = pm.math.dot(past_q_residuals, theta_rev)
    residual = observed - (mu + ma_term)
    # Slide the window forward: drop the oldest residual, append the newest.
    next_window = at.concatenate([past_q_residuals[1:], [residual]])
    return residual, next_window


with pm.Model() as model_ma:

    # Register the series as a mutable data container named 'observed'.
    observed = pm.MutableData('observed', y_ma)
    
    # Priors on the level of the series and on the innovation scale.
    mu = pm.Normal('mu', 0, 10)
    sigma = pm.HalfNormal('sigma', 10)
    
    # One more lag than the simulated process has; the statsmodels fit uses
    # the same order (len(theta_gt) + 1).
    q = q_gt + 1
    # MA coefficients: Laplace(0, 0.2) prior truncated to [-1, 1], one per lag.
    theta = pm.Truncated(
        'theta',
        pm.Laplace.dist(0, 0.2, shape=q),
        lower=-1,
        upper=1
    )
    
    # Latent residuals that seed the recursion (the q residuals "before" the
    # first observation), given the same Normal(0, sigma) law as the innovations.
    first_q_residuals = pm.Normal('first_q_residuals', mu=0, sigma=sigma, shape=q)
    
    # Run residual_fn once per observation. In outputs_info, None means the
    # first output (the residual) is collected but not fed back, while the
    # second output (the residual window) is the recurrent state, initialised
    # with first_q_residuals. theta is reversed because residual_fn keeps the
    # window ordered oldest first (it appends the newest residual at the end).
    (residual, _), _ = aesara.scan(
        fn=residual_fn,
        sequences=observed,
        outputs_info=[None, first_q_residuals],
        non_sequences=[mu, theta[::-1]]
    )
    # Likelihood: the density of the value 0 under Normal(residual, sigma),
    # which equals the density of each residual under Normal(0, sigma) because
    # the Normal pdf is symmetric.
    # NOTE(review): observed=[0] has length 1 while residual has one entry per
    # observation; this presumably relies on broadcasting -- confirm.
    pm.Normal('err', mu=residual, sigma=sigma, observed=[0])
    
    # Point estimate first, then posterior draws; both with default settings.
    MAP_ma = pm.find_MAP()
    trace_ma = pm.sample()

# Collect the pymc estimates: MAP values, posterior means and 95% HDIs.
posterior = trace_ma.posterior
sample_dims = ["chain", "draw"]

mu_map = MAP_ma["mu"]
theta_map = MAP_ma["theta"]

mu_mcmc_mean = posterior.mu.mean(sample_dims).values
theta_mcmc_mean = posterior.theta.mean(sample_dims).values

mu_mcmc_interval = az.hdi(posterior.mu, hdi_prob=0.95).mu.values
theta_mcmc_intervals = az.hdi(posterior.theta, hdi_prob=0.95).theta.values

# Horizontal bar chart comparing the estimates of mu across methods.
# Ground truth and MAP are point values, so they get zero-length error bars.
no_error = [0, 0]
mu_bar_positions = [0, 1, 2, 3]
mu_bar_labels = ["ground truth", "statsmodels", "pymc mcmc", "pymc map"]
mu_bar_values = [mu_gt, mu_sm, mu_mcmc_mean, mu_map]
# Transposed to shape (2, 4): row 0 = lower distances, row 1 = upper distances.
mu_bar_errors = np.array([
    no_error,
    np.abs(mu_sm_interval - mu_sm),
    np.abs(mu_mcmc_interval - mu_mcmc_mean),
    no_error,
]).T

plt.figure(figsize=(6, 2))
plt.title(r"$\mu$")
plt.barh(mu_bar_positions, mu_bar_values, xerr=mu_bar_errors)
plt.yticks(mu_bar_positions, labels=mu_bar_labels);

[image: output of the μ comparison bar chart produced by the code above]

# Compare the MA coefficient estimates lag by lag.
# Open a fresh figure first, as the other plots in this script do; without it,
# running the file top to bottom draws these curves onto the mu bar chart.
plt.figure()

# Lags 1..q for the fitted models.
idx = np.arange(1, len(theta_mcmc_mean) + 1)
# The ground truth has one coefficient fewer than the fitted models.
plt.plot(idx[:-1], theta_gt, "--o", lw=1, label="ground truth")
# Intervals are stored as (q, 2); errorbar wants (2, q) distances from the
# centre value, hence the transpose and the absolute difference.
plt.errorbar(idx, theta_sm,
             yerr=np.abs(theta_sm_intervals.T - theta_sm),
             linestyle="--", marker="o", elinewidth=1, lw=1, capsize=3,
             label="statsmodels [CI 0.95]")
plt.errorbar(idx, theta_mcmc_mean,
             yerr=np.abs(theta_mcmc_intervals.T - theta_mcmc_mean),
             linestyle="--", marker="o", elinewidth=1, lw=1, capsize=3,
             label="pymc mcmc [HDI 0.95]")
plt.plot(idx, theta_map, linestyle="--", marker="o", lw=1, label="pymc MAP")
plt.axhline(0, c="k")
# Legend placed outside the axes, to the right of the plot.
plt.legend(loc="upper left", bbox_to_anchor=(1, 1));