Incorrect inference of hidden states in state space models

I am using a state space model for time series modeling, with the following specification:

X_t = X_{t-1} + c + \epsilon_t, \quad \epsilon_t \sim N(0, \sigma^2) \\
y^i \sim \text{Poisson}(f(z^i, X_t))
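To make the generative process concrete, here is a toy simulation of it (f is assumed to be an exp-linear link purely for illustration, and all shapes are placeholders):

import numpy as np

rng_np = np.random.default_rng(0)
T, latent_factors, n_series = 200, 3, 10
c, sigma = 0.05, 0.1
Z = rng_np.normal(size=(n_series, latent_factors))  # loadings z^i

X = np.zeros((T, latent_factors))
for t in range(1, T):
    # random walk with drift c and innovation sigma
    X[t] = X[t - 1] + c + rng_np.normal(0.0, sigma, size=latent_factors)

rates = np.exp(Z @ X.T)     # assumed form of f(z^i, X_t), shape (n_series, T)
y = rng_np.poisson(rates)   # Poisson observations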

For the state transitions I am using PyTensor's scan function:

# Priors on the initial latent state and the innovation variances
x0_ar = pm.Normal("x0_ar", 0, sigma=1, initval=init_ar, shape=(latent_factors,), rng=rng)
sigmas_Q_ar = pm.InverseGamma("sigmas_Q_ar", alpha=3, beta=0.5, shape=(latent_factors,), rng=rng)
Q_ar = pt.diag(sigmas_Q_ar)  # diagonal innovation precision

def step_simple(x, A, Q, bias_latent):
    # One transition step: x_t = A x_{t-1} + bias + innovation
    innov = pm.MvNormal.dist(mu=0.0, tau=Q)
    next_x = pm.math.dot(A, x) + innov + bias_latent
    return next_x, collect_default_updates([next_x])

ar_states_pt, ar_updates = pytensor.scan(step_simple,
                                         outputs_info=[x0_ar],
                                         non_sequences=[A_ar, Q_ar, bias_latent],
                                         n_steps=T,
                                         strict=True)
assert ar_updates

model_minibatch.register_rv(ar_states_pt, name="ar_states_pt", initval=pt.zeros((T, latent_factors)))
ar_states = pm.Deterministic("ar", pt.concatenate([x0_ar.reshape((1, latent_factors)), ar_states_pt], axis=0))

ar_states is then passed as input to the f function. When I run inference with ADVI, the estimates of ar_states do not follow the state space dynamics; they just fit the training data in whatever way they can. Am I missing anything in the training process? How can I make sure the ar_states follow the dynamics? Any suggestion @jessegrabowski?
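For reference, the fit is run roughly like this (a sketch; the exact arguments in my actual script differ):

with model_minibatch:
    approx = pm.fit(n=50_000, method="advi")
idata = approx.sample(1_000)

ar_posterior = idata.posterior["ar"]  # inferred latent states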

Edit 1: Even when A is 0, the states still fit the data very well. I would expect them to follow the distribution of noise + bias_latent; why is that?
Edit 2: Out of T time points, I have data for T/2 of them. The model infers the states very well for those T/2 time points, but for the rest the values hover around the initial value of the registered variable when A is an identity matrix. Below is an example where A is the identity, so each time point should just be the previous one plus some noise, yet the model fits the data for half of the points and converges towards the initial value of the registered random variable for the rest.

Your model just defines a prior over the sequence; the posterior is free to be whatever it wants to be. You are defining a random walk with drift, so each subsequent point is regularized towards the previous one, but if the innovation sigma is sufficiently large this doesn’t matter.
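To see what the prior alone does, here is a quick toy prior check (a univariate random walk, not your model; try a small vs. a large sigma):

import pymc as pm

T = 100
with pm.Model():
    sigma = 0.1  # try 0.1 vs. 5.0
    x = pm.GaussianRandomWalk(
        "x", mu=0.0, sigma=sigma, init_dist=pm.Normal.dist(0, 1), steps=T
    )
    prior = pm.sample_prior_predictive(500)

# Spread of the prior at the last time step: with a large sigma the prior
# barely constrains the path, so the posterior can follow the data almost freely.
print(prior.prior["x"].isel(x_dim_0=-1).std().values)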


Thanks a lot for the response! Oh yes, that makes sense. So let’s say I generate the data using the below dynamics

X_t = A X_{t-1} + \epsilon_t, \quad \epsilon_t \sim N(0, \sigma^2)

but in PyMC I define the model as a random walk. So the posterior estimates could still look very accurate even though the model I specified is wrong; is my understanding right? Is there a way to tell that my model specification is bad? One way would be to make X_t not a random variable but a deterministic one given the initial state and \sigma^2, and then measure how far off these deterministic states are from the true values, as in the sketch below. Does that make sense?
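Something like this rough sketch, reusing x0_ar, A_ar, bias_latent and T from the model above (inside the same model context):

def det_step(x, A, bias_latent):
    # transition with no innovations, so the path is fixed given x0_ar
    return pm.math.dot(A, x) + bias_latent

det_states, _ = pytensor.scan(det_step,
                              outputs_info=[x0_ar],
                              non_sequences=[A_ar, bias_latent],
                              n_steps=T)
det_path = pm.Deterministic(
    "det_path", pt.concatenate([x0_ar.reshape((1, latent_factors)), det_states], axis=0)
)
# compare det_path against the posterior of "ar" to see how far the inferred
# states drift from the pure dynamics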

I guess it depends on what you mean by “super accurate”? The error bars in the plot you posted are quite wide.

The usual tools apply for model validation: posterior predictive checks and whatnot. PSIS-LOO is also still valid-ish, see here. I tried to implement a time-series LOO here, but I never finished it. Help wanted!
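Something along these lines (a sketch, assuming idata holds your posterior draws, e.g. from approx.sample, and using current PyMC/ArviZ function names):

import arviz as az

with model_minibatch:
    pm.sample_posterior_predictive(idata, extend_inferencedata=True)
    pm.compute_log_likelihood(idata)  # needed before az.loo on VI draws

az.plot_ppc(idata)    # posterior predictive check
print(az.loo(idata))  # PSIS-LOO, with the caveats from the FAQ link above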

If your model isn’t too large, more traditional cross-validation adapted to time series data – sliding window or expanding window – are good. Scikit-learn has tools for doing that. That’s equivalent to making a “porcupine plot” like I show in this notebook. This is an example of “dynamic forecasting”, which is just rolling the transition equation forward from a fixed point, and compute the cumulative error. Here’s a statsmodels example.

Note that my porcupine plot is fake in the sense that I don’t re-fit the model before making each forecast trajectory, so there is data leakage. This is also true if you compute the LOO using az.loo or az.compare, but these can still be useful quantities for model comparison, as the LOO FAQ link above notes.


Thanks a lot! The links you shared are very useful.