PyMC-extras state-space model for dynamic control

Hi,

I wonder if anyone has tried using state space models for dynamic control similar to how Kalman filters are used? Like this: Kalman-and-Bayesian-Filters-in-Python/06-Multivariate-Kalman-Filters.ipynb at master · rlabbe/Kalman-and-Bayesian-Filters-in-Python · GitHub

That seems like a good use case for dealing with non-linearities/non-Gaussian systems.

On the surface it seems easy to feed the control in as an exogenous variable, but I struggle to conceptualize how to do the classical Kalman predict/update loop in PyMC.

cc @jessegrabowski i know you like this stuff :slight_smile:

Yes, I have been meaning to add tools to do this for a long time.

First, we have pt.linalg.solve_discrete_are, so once you define a linear-quadratic loss function, you can plug it in, along with the A and B matrices from the state transition equation, to get the sequence of optimal actions. Solving this algebraic Riccati equation is at the heart of dynamic control.
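For reference, here is a sketch of the DARE-to-feedback-gain step using the scipy equivalent (pt.linalg.solve_discrete_are mirrors the scipy signature; the system and cost matrices below are made up for illustration):

```python
import numpy as np
from scipy.linalg import solve_discrete_are

# Toy double integrator: position/velocity state, acceleration control
dt = 0.1
A = np.array([[1.0, dt], [0.0, 1.0]])
B = np.array([[0.0], [dt]])
Q = np.diag([1.0, 0.1])   # cost on being away from the desired state
R = np.array([[0.5]])     # cost on using the control

# Solve the discrete algebraic Riccati equation, then form the LQR gain
P = solve_discrete_are(A, B, Q, R)
K = np.linalg.solve(R + B.T @ P @ B, B.T @ P @ A)

# The optimal policy is u_t = -K @ x_t; the closed loop A - B K is stable,
# i.e. all its eigenvalues lie inside the unit circle
print(np.abs(np.linalg.eigvals(A - B @ K)).max())
```

The trade-off mentioned below is exactly the ratio of Q to R: scaling R up makes moving the knobs expensive and the gain more timid.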

What I mean by A and B is:

x_t = A x_{t-1} + d + B u_t + \varepsilon_t

So A holds your state transition dynamics (it’s called T everywhere in pymc-statespace, following statsmodels, who follow Durbin and Koopman), and B is your mapping from control inputs u_t to the state space.

Note though that we can define \tilde{x}_t = \begin{bmatrix} x_t \\ u_t \end{bmatrix} and T = \begin{bmatrix} A & B \\ 0 & 0 \end{bmatrix} (and \tilde{d} = \begin{bmatrix} d \\ 0 \end{bmatrix}, but nobody cares about d). Letting each new control input enter through the shock on the u-block, convince yourself that now:

\tilde{x}_t = T \tilde{x}_{t-1} + \tilde{d} + \tilde{\varepsilon}_t

is the same system (up to a one-period shift in the timing of u, since the augmented form feeds u_{t-1} into the x-block).
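A quick numpy sanity check of one way to write this augmentation (B placed in the upper-right block, the recorded control entering through the u-block’s shock, timing convention x_t = A x_{t-1} + B u_{t-1}, noise omitted; all numbers are made up):

```python
import numpy as np

rng = np.random.default_rng(0)
A = np.array([[1.0, 0.1], [0.0, 1.0]])
B = np.array([[0.0], [0.1]])
u = rng.normal(size=(20, 1))   # a recorded sequence of control inputs

# Original recursion: x_t = A x_{t-1} + B u_{t-1}
x = np.zeros(2)
xs = [x.copy()]
for t in range(1, 20):
    x = A @ x + B @ u[t - 1]
    xs.append(x.copy())

# Augmented recursion with T = [[A, B], [0, 0]]
T = np.block([[A, B], [np.zeros((1, 2)), np.zeros((1, 1))]])
z = np.concatenate([np.zeros(2), u[0]])
zs = [z.copy()]
for t in range(1, 20):
    z = T @ z
    z[2:] = u[t]   # the u-block "shock" injects the new control input
    zs.append(z.copy())

# The x-block of the augmented system matches the original recursion
print(np.allclose(np.array(xs), np.array(zs)[:, :2]))
```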

So you can approach this in at least two ways:

  1. You can make a custom statespace with this block diagonal structure, and fit it on a sequence of states and recorded control inputs (which you will treat like endogenous states). Once you’re done, you split the T back into A and B and feed them into solve_discrete_are.
  2. Open a PR that allows users to directly pass B and u to the existing Kalman filters :slight_smile:
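For option 1, splitting the fitted T back apart is just block slicing. A minimal sketch (T_fitted, n_states, and n_controls are made-up placeholders, and this assumes the augmentation puts B in the upper-right, state-by-control block):

```python
import numpy as np

n_states, n_controls = 2, 1
T_fitted = np.arange(9.0).reshape(3, 3)    # stand-in for the estimated transition matrix
A = T_fitted[:n_states, :n_states]          # state-to-state dynamics
B = T_fitted[:n_states, n_states:]          # control-to-state mapping
```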

This series of lectures about LQ control is quite nice. Translating them to PyMC would be a huge benefit to the community.


On second reading, it doesn’t seem you want an optimal policy, just to do online filtering? If that’s the case, the setup is much easier. Just compile the required pytensor functions and feed in data:

import pytensor
import pytensor.tensor as pt
from pymc_extras.statespace.filters import StandardFilter
import numpy as np

# Use the model from the textbook "dog tracking" example
dt = pt.dscalar('dt')
sigma_sq = pt.dscalar('sigma_sq')
measurement_sigma_sq = pt.dscalar('measurement_sigma_sq')
data = pt.tensor('data', shape=(1,))


# All the names are different from the textbook, we are consistent with Durbin and Koopman
# In general, engineers and economists disagree on the names, but shuffle the same letters :)
x = pt.tensor('x', shape=(2,))
P = pt.tensor('P', shape=(2, 2))

c = pt.zeros((2,))
d = pt.zeros((1,))

# T = [[1.0, dt],
#      [0.0, 1.0]]
T = pt.eye(2)[0, 1].set(dt)
Z = pt.zeros((1, 2))[0, 0].set(1.)
R = pt.eye(2)
H = pt.eye(1) * measurement_sigma_sq

# This equation comes from textbooks, look up "Van Loan discretization"
Q = sigma_sq * pt.stacklists([[dt ** 4 / 4, dt ** 2 / 2],
                              [dt ** 2 / 2, dt]])

# This is currently "off label" usage, so we have to hack a bit...
kf = StandardFilter()
kf.cov_jitter = 1e-8
kf.n_endog = 1
kf.n_states = 2
kf.n_shocks = 1

# Super important! In Durbin and Koopman, x0 is a *prediction*, so one has to run the filter first.
# In Labbe, x0 is a *filtered value*, so one predicts first.

# Unused return values are: y_hat, F, ll
x_filtered, P_filtered, *_ = kf.update(x, P, data, d, Z, H, pt.isnan(data).all())
x_predicted, P_predicted = kf.predict(x_filtered, P_filtered, c, T, R, Q)

f_filter = pytensor.function([x, P, data, dt, sigma_sq, measurement_sigma_sq], 
                             [x_filtered, P_filtered], 
                             on_unused_input='ignore')
f_predict = pytensor.function([x, P, data, dt, sigma_sq, measurement_sigma_sq], 
                              [x_predicted, P_predicted], 
                              on_unused_input='ignore')

Reproduce the “dog tracking” simulation:

import copy
import math
import numpy as np
from numpy.random import randn

def compute_dog_data(z_var, process_var, count=1, dt=1.):
    "returns track, measurements 1D ndarrays"
    x, vel = 0., 1.
    z_std = math.sqrt(z_var) 
    p_std = math.sqrt(process_var)
    xs, zs = [], []
    for _ in range(count):
        v = vel + (randn() * p_std)
        x += v*dt
        xs.append(x)
        zs.append(x + randn() * z_std)        
    return np.array(xs), np.array(zs)

x_hat = np.array([0, 0])
P_hat = np.diag([500., 49.])
innovation_var = 0.01
measurement_var = 10.0

n_steps = 50
dt_val = 1.0
dt_values = np.full(n_steps, dt_val)
hidden_states, data = compute_dog_data(z_var=measurement_var, 
                                       process_var=innovation_var, 
                                       count=n_steps,
                                       dt=dt_val)

state_history = np.zeros((n_steps, 4))
cov_history = np.zeros((n_steps, 8))

state_history[0, 2:] = x_hat
cov_history[0, 4:] = P_hat.ravel()

for t, (y, delta_t) in enumerate(zip(data, dt_values)):
    x_tt, P_tt = f_filter(x=x_hat, P=P_hat, data=y[None], dt=delta_t, 
                          sigma_sq=innovation_var, 
                          measurement_sigma_sq=measurement_var)
    x_hat, P_hat = f_predict(x=x_tt, P=P_tt, data=y[None], dt=delta_t,
                             sigma_sq=innovation_var,
                             measurement_sigma_sq=measurement_var)
    state_history[t, :] = np.r_[x_tt.ravel(), x_hat.ravel()]
    cov_history[t, :] = np.r_[P_tt.ravel(), P_hat.ravel()]

Plot results:

import matplotlib.pyplot as plt
fig, ax = plt.subplots(figsize=(14, 4))
ax.plot(dt_values.cumsum(), data)

mu_filtered = state_history[:, :2]
var_filtered = np.diagonal(cov_history[:, :4].reshape((-1, 2, 2)), axis1=1, axis2=2)
std_filtered = np.sqrt(var_filtered)  # the diagonal holds variances, not standard deviations

low = mu_filtered - 2 * std_filtered
high = mu_filtered + 2 * std_filtered

ax.plot(dt_values.cumsum(), mu_filtered[:, 0])

ax.fill_between(dt_values.cumsum(), low[:, 0], high[:, 0], color='tab:orange', alpha=0.25)
ax.plot(dt_values.cumsum(), hidden_states)

Result: [plot of the measurements, the filtered position with a 2-sigma band, and the hidden states]

For the record, I’d still like it if someone came and worked on optimal linear-quadratic control problems. I think there are interesting applications in e.g. marketing :slight_smile:

Also I’d like help making some tools to simplify all this setup, issue here.

At the risk of spamming, one more note. There’s no control in the dog tracking example, or anywhere in the linked chapter of Labbe, although he does mention that the KF can handle control inputs. If you want an optimal policy for choosing the control inputs given a state x_t and the dynamics encoded in your transition function, this is where all the algebraic Riccati equation machinery comes in.

I’d like to cook up an example where we do optimal control to obtain and use a policy function with pytensor/pymc but it’ll have to wait for another day.

@jessegrabowski Thank you very much for such an extensive response! No spamming at all. Ideally yes, I want an optimal control policy. I need to go through your examples in a bit more detail.

The issue here Online Learning with State Space Linear Kalman Filter · Issue #540 · pymc-devs/pymc-extras · GitHub is very similar to what I think I need :slight_smile: . I can help if I can, although I’m not very familiar with pytensor, more on the pymc side.

Looking forward to your optimal control example.

Can you please correct the [Math Processing Error] in your first post?

Just to be clear, the control problem (learning an optimal policy) is a fundamentally different thing than the filtering problem (learning the system dynamics). You can do them together, but they aren’t somehow joined at a deep level. You can have any black-box policy function \pi(x_t) = u_t to control the system. LQ is nice because it starts from the same linear setup as the KF, but it’s not the only tool to be used here. If you start from already knowing the system dynamics and you want a policy function, you have a fundamentally different problem from what a KF gives you, online or not.
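To make the separation concrete, here is a tiny numpy sketch (all numbers made up; the gain K is just a hand-picked stabilizing matrix standing in for whatever policy \pi you choose, and the noisy observation stands in for a filter’s state estimate):

```python
import numpy as np

# Toy illustration that estimation and control are separable modules:
# some filter produces an estimate x_hat, and any policy maps x_hat -> u.
rng = np.random.default_rng(42)
A = np.array([[1.0, 1.0], [0.0, 1.0]])
B = np.array([[0.5], [1.0]])
K = np.array([[0.3, 0.6]])    # hypothetical stabilizing gain (could come from LQR)

x = np.array([5.0, 0.0])      # true state, unknown to the controller
for _ in range(50):
    x_hat = x + rng.normal(scale=0.1, size=2)   # stand-in for a KF's filtered state
    u = -K @ x_hat                               # black-box policy pi(x_hat)
    x = A @ x + B @ u + rng.normal(scale=0.01, size=2)

# With a stabilizing K, the state is driven toward the origin despite the noise
print(x)
```

Swapping in a better estimator (a real Kalman filter) or a better policy (the LQR gain from solve_discrete_are) changes only one of the two modules; the loop structure stays the same.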

Yes, I get that.

I already have some “home-grown” KF implementation, which I ideally want to migrate to a pymc state-space model. The measurements are done in the integer space, I think it’s much easier to let pymc do the hidden state estimation rather than mess with custom code producing the variances.

My other goal is to re-implement the feedback control (which is some heuristic that works) in a more principled way. LQR sounds really nice, as it allows setting a trade-off between the state being far from the desired value and the cost of turning the knobs.