Making predictions in AR model, exploding AR states during posterior inference

Hi, thanks to a lot of input from the community, I was able to implement a complex model that combines an AR model and a neural network. Below is the code for the AR part:

with pm.Model() as Mix:
    def step_simple(x, A, Q, t, bais_latent):
        # One AR step: linear transition plus Gaussian innovation (Q is a precision matrix)
        innov = pm.MvNormal.dist(mu=0.0, tau=Q)
        next_x = pt.dot(A, x) + innov + bais_latent
        t = t + 1
        return next_x, collect_default_updates(inputs=[x, A, Q, t, bais_latent], outputs=[next_x])

    x0_ar = pm.Normal("xo_ar", 0, sigma=1, initval=init_ar, shape=(latent_factors))
    sigmas_Q_ar = pm.InverseGamma('sigmas_Q_ar', alpha=3, beta=0.5, shape=(latent_factors))
    Q_ar = pt.diag(sigmas_Q_ar)
    t = 0
    ar_states_pt, ar_updates = pytensor.scan(step_simple,
                                              outputs_info=[x0_ar],
                                              non_sequences=[A_ar, Q_ar, t, bais_latent],
                                              n_steps=120)  # length of the training series
    Mix.register_rv(ar_states_pt, name='ar_states_pt')

These AR states feed into a neural network that makes the predictions. I have been training the above model on a time series of length 120; training works fine and the AR states take reasonable values. Now I want to use the last state to predict the next state, which the neural network will then use to make a prediction. But when I run sample_posterior_predictive, the values of ar_states_pt explode.
Here is what I am doing:

with pm.Model() as prediction:
    prediction.register_rv(ar_states_pt, name='ar_states_pt')
    # use ar_states_pt[-1, :] to feed into neural network

I then run posterior predictive sampling on the prediction model, and ar_states_pt explodes even though its values in the trace are reasonable. What is the problem here?

This all looks very reasonable to me. You are just missing a couple of pieces:

  1. Wrap your scan in a function and have it return ar_states_pt. Pass this function to the dist argument of pm.CustomDist instead of using mix.register_rv, see here for an example.
  2. Pass t to scan via sequences, as in sequences=[pt.arange(T)].
  3. But instead of pt.arange(T), make the timesteps a pm.MutableData, as in time = pm.MutableData('time', np.arange(T))

After you do all that, you will be ready to make out-of-sample predictions in a prediction block like this (warning – untested code):

with pm.Model() as predictions:
    pm.DiracDelta('x0_ar', last_state_data)
    ar_states_pt = pm.Flat('ar_states_pt', shape=(1, latent_factors))
    idata_pred = pm.sample_posterior_predictive(idata, var_names=['ar_states_pt'])

Thanks for your response. Probably a stupid question, but how do I extract last_state_data? From the trace?

I assumed it comes from your observed data. If not, then yes, you can just give it a sample (or samples) from your trace.
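
If the last state is taken from the trace, one way to get it is to slice the final time step out of the posterior samples. The sketch below assumes the samples have shape (chain, draw, time, latent_factors), as they would if read with something like idata.posterior["ar_states_pt"].values. The helper name last_state_from_samples and the toy array are hypothetical and only stand in for the real trace.

```python
import numpy as np


def last_state_from_samples(samples):
    """Return the final AR state for every posterior draw.

    `samples` is assumed to have shape (chain, draw, time, latent_factors).
    The result has shape (chain * draw, latent_factors).
    """
    samples = np.asarray(samples)
    if samples.ndim != 4:
        raise ValueError("expected shape (chain, draw, time, latent_factors)")
    last = samples[:, :, -1, :]              # keep only the final time step
    return last.reshape(-1, last.shape[-1])  # stack chains and draws


# Toy stand-in for the trace: 2 chains, 3 draws, 120 time steps, 4 factors.
toy_samples = np.arange(2 * 3 * 120 * 4, dtype=float).reshape(2, 3, 120, 4)
last_states = last_state_from_samples(toy_samples)
last_state_mean = last_states.mean(axis=0)  # a single point summary, if one is needed
```

Either the full set of draws (last_states) or a point summary (last_state_mean) could then serve as last_state_data, depending on whether the uncertainty in the last state should carry into the prediction.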

Oh no, the last state comes from the learned model. It might seem stupid to do it this way, but my complete model is configured like this, and this is only one part of it.
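
Since the last state comes from the learned model, the next state could in principle also be computed outside PyMC by pushing each posterior draw through the AR update once. The following is a minimal NumPy sketch, not the method proposed in this thread. It assumes the update is linear in A and x, that Q is a diagonal precision matrix (as tau=Q with pt.diag suggests), and that matching posterior draws of A, the bias, and the diagonal of Q are available. The function name and all toy values are hypothetical.

```python
import numpy as np


def one_step_ahead(x_last, A, bias, q_diag, rng):
    """Draw x_next = A @ x_last + bias + innovation for every posterior draw.

    x_last: (n_draws, k) last AR state per draw
    A:      (n_draws, k, k) transition matrix per draw
    bias:   (n_draws, k) additive bias per draw
    q_diag: (n_draws, k) diagonal of the precision matrix Q per draw
    """
    mean = np.einsum("dij,dj->di", A, x_last) + bias
    # tau=Q is a precision, so the innovation standard deviation is 1/sqrt(q).
    innov = rng.normal(size=mean.shape) / np.sqrt(q_diag)
    return mean + innov


rng = np.random.default_rng(0)
n_draws, k = 5, 3
x_last = np.ones((n_draws, k))
A = np.tile(0.5 * np.eye(k), (n_draws, 1, 1))  # toy, stable transition
bias = np.zeros((n_draws, k))
q_diag = np.full((n_draws, k), 1e12)           # huge precision: almost no noise
x_next = one_step_ahead(x_last, A, bias, q_diag, rng)
```

Each row of x_next is one draw of the next state, so the spread across rows reflects the posterior uncertainty that would be passed on to the neural network.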