Multivariate Random Walk with missing values

Hey all!

I am trying to use a multivariate normal random walk to forecast future values. The usual way to do this (at least in JAGS) is to append NaNs to the observed data for the time steps you want predicted. I tried to do the same in PyMC3, because I want to try out some of its very nice VI tools, but I can't quite get it to work. Below is a minimal reproducible example: toggle the NaN on and it fails; leave it off and it samples successfully.

Thanks!

```python
import numpy as np
import pymc3 as pm
import theano.tensor as T
import theano.tensor.slinalg as sla
from pymc3.distributions import Continuous


class mvNormalRandomWalk(Continuous):
    """Random walk whose increments x[t] - x[t-1] are MvNormal(mu, cov)."""

    def __init__(self, mu=0., cov=1., *args, **kwargs):
        super(mvNormalRandomWalk, self).__init__(*args, **kwargs)
        self.cov = cov
        self.mu = mu

    def logp(self, x):
        mu = self.mu

        # Consecutive rows: x_im1 = x[t-1], x_i = x[t]; the first row is conditioned on
        x_im1 = x[:-1]
        x_i = x[1:]

        # cov = L L^T, so log_det below is half the log-determinant of cov
        L = sla.cholesky(self.cov)
        log_det = T.log(L.diagonal()).sum()
        delta = x_i - (x_im1 + mu)

        # Whitened increments L^{-1} delta, one column per time step
        solve_lower_triangular = sla.Solve(A_structure='lower_triangular', lower=True)
        Linv_delta = solve_lower_triangular(L, delta.T)

        k = L.shape[0]
        innov_like = -(0.5*k*T.log(2*np.pi) + log_det + 0.5*T.sum(Linv_delta*Linv_delta, axis=0))
        return T.sum(innov_like)


n_samples = 5000
X = np.random.normal(size=(3, 3))   # 3 time steps of a 3-dimensional series
Sigma = np.random.randn(3, 5)
Sigma = Sigma.dot(Sigma.T)          # random positive-definite 3x3 covariance

# TURN OFF OR ON: with this line uncommented, sampling fails
# X[0, 0] = np.nan

with pm.Model() as model:
    mu = pm.MvNormal('mu', mu=np.zeros(3), cov=np.eye(3), shape=3)
    likelihood = mvNormalRandomWalk('y', mu=mu, cov=Sigma, observed=X[0:3, 0:3])
    step = pm.NUTS()
    trace = pm.sample(n_samples, step=step)
```
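To check what `mvNormalRandomWalk.logp` computes, here is a plain NumPy/SciPy port of it (a sketch, not part of PyMC3), compared against scoring each increment with `scipy.stats.multivariate_normal`. The names `rw_logp` and `rw_logp_reference` are made up for this illustration. It also shows that a single NaN in `X` propagates into `delta` and makes the whole log-probability undefined, which is presumably why the NaN version fails (depending on the solver, you get either a NaN log-probability or a finiteness error).

```python
import numpy as np
from scipy.linalg import cholesky, solve_triangular
from scipy.stats import multivariate_normal


def rw_logp(x, mu, cov):
    """NumPy port of mvNormalRandomWalk.logp (the first row is conditioned on)."""
    x = np.asarray(x, dtype=float)
    L = cholesky(cov, lower=True)                  # cov = L @ L.T
    delta = x[1:] - (x[:-1] + mu)                  # drift-adjusted increments, shape (T-1, k)
    # check_finite=False lets NaNs propagate instead of raising a ValueError
    z = solve_triangular(L, delta.T, lower=True, check_finite=False)
    k = L.shape[0]
    log_det = np.log(np.diag(L)).sum()             # equals 0.5 * log|cov|
    per_step = -(0.5 * k * np.log(2 * np.pi) + log_det + 0.5 * np.sum(z * z, axis=0))
    return per_step.sum()


def rw_logp_reference(x, mu, cov):
    """Same quantity: each increment x[t] - x[t-1] scored under MvNormal(mu, cov)."""
    increments = np.diff(np.asarray(x, dtype=float), axis=0)
    return multivariate_normal.logpdf(increments, mean=mu, cov=cov).sum()


rng = np.random.default_rng(0)
A = rng.normal(size=(3, 5))
Sigma = A @ A.T                     # positive-definite 3x3 covariance, as in the question
mu = np.array([0.1, -0.2, 0.0])
X = rng.normal(size=(4, 3))         # 4 time steps of a 3-dimensional series

print(np.isclose(rw_logp(X, mu, Sigma), rw_logp_reference(X, mu, Sigma)))  # True

X_nan = X.copy()
X_nan[0, 0] = np.nan
print(rw_logp(X_nan, mu, Sigma))    # nan
```

The Cholesky form avoids inverting `cov` explicitly: the quadratic form is the squared norm of the whitened increments, and the log-determinant comes from the diagonal of `L`.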

Have you had a look at the discussion in "Multivariate Normal with missing inputs"? I suspect the solution there might apply.
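For reference, PyMC3 can impute missing observations when `observed` is a NumPy masked array (for example `np.ma.masked_invalid(X)` instead of an array containing NaNs): the masked cells become free variables, and PyMC3 emits an `ImputationWarning`. Whether that works unchanged with the custom `mvNormalRandomWalk` above is an assumption to test. The NumPy/SciPy sketch below only mimics the idea outside PyMC3: it treats the masked cell as a free parameter, finds its mode under a flat prior, and checks the result against the Gaussian conditional mean. `logp_given_missing` and `x00_closed_form` are hypothetical names for this illustration.

```python
import numpy as np
from scipy.optimize import minimize_scalar
from scipy.stats import multivariate_normal


def rw_logp(x, mu, cov):
    # Sum of MvNormal(0, cov) log-densities of the drift-adjusted increments.
    delta = np.diff(x, axis=0) - mu
    return multivariate_normal.logpdf(delta, mean=np.zeros(len(mu)), cov=cov).sum()


rng = np.random.default_rng(1)
A = rng.normal(size=(3, 5))
Sigma = A @ A.T
mu = np.array([0.1, -0.2, 0.0])
X = rng.normal(size=(4, 3)).cumsum(axis=0)   # a short 3-dimensional random walk
X[0, 0] = np.nan                             # the entry toggled in the question

X_masked = np.ma.masked_invalid(X)           # what would be passed as observed=


def logp_given_missing(values):
    # Treat the masked cells as free parameters and score the completed series.
    filled = X_masked.filled(0.0)            # dense copy; masked cells set to 0.0
    filled[X_masked.mask] = values
    return rw_logp(filled, mu, Sigma)


# Mode of the missing cell under a flat prior on it.
res = minimize_scalar(lambda v: -logp_given_missing(v))

# Closed form: only the first increment d = x_1 - x_0 - mu involves x_00, so the
# optimum sets d_0 to its Gaussian conditional mean given the observed d_1, d_2.
d = X[1] - X[0] - mu
d0_hat = Sigma[0, 1:] @ np.linalg.solve(Sigma[1:, 1:], d[1:])
x00_closed_form = X[1, 0] - mu[0] - d0_hat

print(res.x, x00_closed_form)
```

In a sampler the missing cell would be drawn from its conditional distribution rather than optimized, but the point is the same: once the cell is a parameter instead of a NaN, the log-probability stays finite.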

Actually, a random walk models the increments, $X_t - X_{t-1} \sim \mathrm{MvNormal}(\theta)$, which makes missing values quite difficult to handle. I will need to think a bit more about it.
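Because the model is defined on increments, forecasting does not strictly require appending NaNs. With the innovation covariance $\Sigma$ fixed (as in the question) and drift $\mu$, the $h$-step-ahead value satisfies $X_{T+h} \mid X_T \sim \mathrm{MvNormal}(X_T + h\mu,\, h\Sigma)$, so future paths can be simulated forward from posterior draws of `mu` after sampling. This is an alternative to NaN-padding, not something settled in this thread; `forecast_random_walk` is a hypothetical helper, and the constant `mu_draws` array stands in for real posterior samples such as `trace['mu']`.

```python
import numpy as np


def forecast_random_walk(x_last, mu_draws, cov, horizon, rng):
    """Simulate x_{T+1}, ..., x_{T+horizon} once per drift draw.

    x_last:   (k,) last observed vector x_T
    mu_draws: (S, k) drift draws, e.g. posterior samples of `mu`
    cov:      (k, k) innovation covariance (fixed, as in the question)
    Returns an array of shape (S, horizon, k).
    """
    S, k = mu_draws.shape
    L = np.linalg.cholesky(cov)                        # cov = L @ L.T
    eps = rng.standard_normal((S, horizon, k)) @ L.T   # each row ~ MvNormal(0, cov)
    steps = mu_draws[:, None, :] + eps                 # x_t - x_{t-1} ~ MvNormal(mu, cov)
    return x_last + np.cumsum(steps, axis=1)


rng = np.random.default_rng(2)
A = rng.normal(size=(3, 5))
Sigma = A @ A.T
x_last = np.array([1.0, 0.0, -1.0])
mu_draws = np.tile([0.5, 0.0, -0.5], (20000, 1))      # hypothetical stand-in for trace['mu']

paths = forecast_random_walk(x_last, mu_draws, Sigma, horizon=4, rng=rng)
end = paths[:, -1]                                     # draws of x_{T+4}
print(end.mean(axis=0))   # close to x_last + 4 * mu = [3, 0, -3]
print(np.cov(end.T))      # close to 4 * Sigma
```

With real posterior draws of `mu`, the spread of the simulated paths combines innovation noise with uncertainty about the drift; uncertainty in $\Sigma$ is not propagated here because it is a constant in the question's model.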