Issue opened 23 Nov 2024:
For systems with many states, Kalman filtering is slow because we have to compute the so-called Kalman gain matrix, $$K = P_\text{prior}Z^T F^{-1}$$, with $$F = ZP_{\text{prior}}Z^T + H$$. From there, we can compute the posterior covariance matrix, $$P_\text{posterior} = P_\text{prior} - KFK^T$$. These matrices have the following shapes and meanings:
| name | description | shape |
| ------|---------------------------------------|---------------------------|
| Z | Map hidden states to observed states | $n_\text{obs} \times n_\text{hidden}$|
| P | Hidden state covariance matrix | $n_\text{hidden} \times n_\text{hidden}$|
| H | Measurement error covariance | $n_\text{obs} \times n_\text{obs}$|
| F | Estimated residual covariance | $n_\text{obs} \times n_\text{obs}$ |
| K | Kalman gain (optimal update to hidden states) | $n_\text{hidden} \times n_\text{obs}$|
Typically the number of observed states will be quite small relative to the number of hidden states, so the inversion of $F$ isn't so bad. But when the number of hidden states is large, computing $$P_\text{posterior}$$ requires at least 3 matrix multiplications. Actually it's more, because we use the "Joseph form", $$P_\text{posterior} = (I - KZ) P_\text{prior} (I - KZ)^T + KHK^T$$, which guarantees the result is always positive semi-definite (because it avoids subtraction between two PSD matrices), but it costs 9 matrix multiplications instead.
Then, to compute next-step forecasts, we have to do several more multiplications:
$$ P_\text{prior}^+ = T P_\text{posterior} T^T + R Q R^T $$
where $$T$$ is the same size as $$P$$ (it encodes the system dynamics), while $$R$$ and $$Q$$ are relatively small (they determine how shocks enter the system).
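To make the cost concrete, here's a minimal numpy sketch of a single covariance recursion (Joseph-form update followed by the prediction step). The matrices are random placeholders rather than a real model; names follow the table above:

```python
import numpy as np

rng = np.random.default_rng(0)
n_hidden, n_obs = 15, 3

# Placeholder system matrices, not a real model
Z = rng.normal(size=(n_obs, n_hidden))
T = 0.1 * rng.normal(size=(n_hidden, n_hidden))  # stable dynamics
R = np.eye(n_hidden)
Q = 0.01 * np.eye(n_hidden)
H = 0.1 * np.eye(n_obs)
P_prior = np.eye(n_hidden)

# Update step
F = Z @ P_prior @ Z.T + H                        # estimated residual covariance
K = P_prior @ Z.T @ np.linalg.inv(F)             # Kalman gain
I_KZ = np.eye(n_hidden) - K @ Z
P_post = I_KZ @ P_prior @ I_KZ.T + K @ H @ K.T   # Joseph-form posterior covariance

# Prediction step
P_prior_next = T @ P_post @ T.T + R @ Q @ R.T
```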
My point is this is all quite expensive to compute. Interestingly though, $$P_\text{posterior}$$ does not depend on the data at all, and if there are no time-dependent matrices, the covariance matrix will converge to a fixed point given by:
$$P = TPT^T + RQR^T - TPZ^T(ZPZ^T +H)^{-1}ZPT^T$$
This is an [algebraic Riccati equation](https://en.wikipedia.org/wiki/Algebraic_Riccati_equation), and can be solved in pytensor as `pt.linalg.solve_discrete_are(A=T.T, B=Z.T, Q=R @ Q @ R.T, R=H)`. Once we have this, we don't need to compute $P_\text{prior}$ or $P_\text{posterior}$ ever again; we can just use the steady-state $P$. And once $P$ is fixed, so is $F^{-1}$, so the filter updates become *extremely* inexpensive.
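For reference, scipy exposes the same solver, so the fixed point can be checked directly. A small self-contained sketch with placeholder matrices (the argument mapping is the same as in the pytensor call above):

```python
import numpy as np
from scipy import linalg

rng = np.random.default_rng(0)
n_hidden, n_obs = 15, 3

# Placeholder stable system, not a real structural model
T = 0.1 * rng.normal(size=(n_hidden, n_hidden))
Z = rng.normal(size=(n_obs, n_hidden))
R = np.eye(n_hidden)
Q = 0.01 * np.eye(n_hidden)
H = 0.1 * np.eye(n_obs)

# Steady-state covariance from the discrete algebraic Riccati equation
P_ss = linalg.solve_discrete_are(a=T.T, b=Z.T, q=R @ Q @ R.T, r=H)

# Check P = T P T^T + R Q R^T - T P Z^T (Z P Z^T + H)^{-1} Z P T^T
rhs = (
    T @ P_ss @ T.T
    + R @ Q @ R.T
    - T @ P_ss @ Z.T @ np.linalg.solve(Z @ P_ss @ Z.T + H, Z @ P_ss @ T.T)
)
assert np.allclose(P_ss, rhs)
```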
How to use this is not 100% clear to me. I had previously made a `SteadyStateFilter` class that computed and used the steady-state $P$ from the first iteration. This means we don't do "online learning" for the first several steps. That *seems* OK to me, since we don't believe all that initial noise anyway. But on the other hand I've never seen this approach suggested in a textbook, so it makes me a bit suspicious that it's the right thing to do. I'm not against offering this option, but one negative is that JAX doesn't offer `solve_discrete_are`, and right now statespace is basically 100% dependent on JAX for sampling.
The "safer" option would be to use an `ifelse` in the `update` step to check for convergence. At every iteration of `kalman_step`, we can compute $||P_\text{prior}^- - P_\text{prior}^+||$ and stop computing updates once its below some tolerance. Here's a graph of the supremium norm for a structural model with level, trend, cycle, and 12 seasonal lags. $P_0$ was initialized to `np.eye(15) * 1e7`. The two plots are the same, but the right plot begins at t=20:

Here's a table of tolerance levels and convergence iterations:
| Tolerance | Convergence Iteration |
|-----------|-----------------------|
| 1 | 25 |
| 1e-1 | 49 |
| 1e-2 | 108 |
| 1e-3 | 216 |
| 1e-4 | 337 |
| 1e-5 | 457 |
| 1e-6 | 583 |
| 1e-7 | 703 |
| 1e-8 | 827 |
We could leave convergence tolerance as a free parameter for the user to play with. But we can see that if we pick `1e-2` for instance, anything after 100 time steps is basically free. This would be quite attractive for estimating large, expensive systems or extremely long time series.
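For illustration, here's a hedged numpy sketch of the convergence check itself: a plain Python loop standing in for the `ifelse` inside `kalman_step`, with placeholder system matrices (which need not converge at the same rate as the structural model behind the table) and `tol` playing the role of the user-facing tolerance:

```python
import numpy as np

rng = np.random.default_rng(0)
n_hidden, n_obs = 15, 3

T = 0.1 * rng.normal(size=(n_hidden, n_hidden))  # placeholder dynamics
Z = rng.normal(size=(n_obs, n_hidden))
RQR = 0.01 * np.eye(n_hidden)                    # R Q R^T, collapsed
H = 0.1 * np.eye(n_obs)

P = np.eye(n_hidden) * 1e7                       # diffuse initialization
tol, converged_at = 1e-2, None

for t in range(1_000):
    # One covariance recursion: update (standard form, for brevity) then predict
    F = Z @ P @ Z.T + H
    K = P @ Z.T @ np.linalg.inv(F)
    P_post = P - K @ F @ K.T
    P_next = T @ P_post @ T.T + RQR

    if np.max(np.abs(P_next - P)) < tol:         # sup norm of the change
        converged_at = t
        break
    P = P_next

print(converged_at)  # after this step, P, K, and F^{-1} could all be frozen
```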
Issue opened 16 Apr 2024 (labels: enhancements, feature request, statespace):
I'm quite interested in the results of [this paper](https://arxiv.org/pdf/2303.16846.pdf). The authors derive closed-form gradients for backprop through Kalman filters; specifically, equations 28-31.
They report a 38x speedup over autodiff gradients from PyTorch. I suspect (with no evidence) that the gradient computations are where the default PyMC sampler really falls down, so this might even make non-JAX sampling of statespace models palatable.
Issue opened 26 Dec 2024 (labels: enhancements, statespace):
The most requested feature for the statespace module is to handle multiple time series in a single MCMC run. This will require support for batch dimensions. ~~I think the easiest way to attack this will be to refactor the `KalmanFilter` class to be an `OpFromGraph`. My original class-based design was inspired by the statsmodels implementation, but it doesn't take full advantage of pytensor.~~
Thinking more about this, I don't know if the KF needs to be an OFG as a first step. It might still be nice to have an `AbstractKalmanFilter` dummy that we can rewrite to specialized cases, but when I wrote this issue I was a bit obsessed with OFG. The custom gradients are still on my mind, though, so the next two sentences remain true:
An additional advantage of this will be the ability to define a custom gradient. See #332.
Finally, it will let us handle special case filters via rewrites, rather than asking the user to pick a filter up front.
These are the currently open issues that would give performance gains. #406 (batch dimensions) is the most important, since it will allow us to actually start benchmarking performance across large panels of time series. #394 (steady-state filtering) would allow for better performance on very long time series, and #332 (closed-form gradients) promises to be an across-the-board speedup. These issues in pytensor are also relevant:
Issue opened 05 Jan 2024 (labels: help wanted, feature request, numba, graph rewriting, performance, linalg):
### Description
[This paper](https://arxiv.org/abs/2309.03060) and [this library](https://github.com/wilson-labs/cola) describe and implement a number of linear algebra simplifications that we can implement as graph rewrites. This issue is both a tracker for implementing these rewrites and a discussion of how to handle them.
Consider the following graph:
```
x = pt.dmatrix('x')
y = pt.diag(pt.diagonal(x))
z = pt.linalg.inv(y)
```
If we could promise at rewrite time that `y` is diagonal, we could rewrite the last operation as `z = pt.diag(1 / pt.diagonal(y))`, exploiting the structure of the diagonal matrix. Other non-trivial examples exist, for example:
```
x = pt.dmatrix('x')
y = pt.eye(3)
z = pt.kron(x, y)
z_inv = pt.linalg.inv(z)
```
If we could promise at rewrite time that `z` has this Kronecker structure, we could rewrite `z_inv = pt.kron(pt.linalg.inv(x), y)` (the identity factor is its own inverse), which is a much faster operation (since `x` is 3x smaller than `z` along each dimension).
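A quick numpy check of the identity behind this rewrite (sizes are arbitrary placeholders; the offset to `x` just keeps it well-conditioned):

```python
import numpy as np

rng = np.random.default_rng()
x = rng.normal(size=(4, 4)) + 4 * np.eye(4)
y = np.eye(3)

direct = np.linalg.inv(np.kron(x, y))       # invert the full 12x12 matrix
rewritten = np.kron(np.linalg.inv(x), y)    # only invert the 4x4 factor

assert np.allclose(direct, rewritten)
```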
The linked paper and library list a huge number of such shortcut computations. The following is a list version of Table 1 from the paper. Under each function are the types of matrices for which a rewrite rule exists to accelerate the function. ~~It would be nice to collaboratively update this list with links to the COLA library where the relevant rewrite is implemented.~~
Thanks to @tanish1729 for compiling a list of links to relevant rewrites. Missing links indicate no direct implementation.
- [ ] [cholesky(A)](https://github.com/wilson-labs/cola/blob/main/cola/linalg/decompositions/decompositions.py)
- [x] [Identity](https://github.com/wilson-labs/cola/blob/main/cola/linalg/decompositions/decompositions.py#L123)
- [x] [Diagonal](https://github.com/wilson-labs/cola/blob/main/cola/linalg/decompositions/decompositions.py#L128)
- [x] [Kronecker Product](https://github.com/wilson-labs/cola/blob/main/cola/linalg/decompositions/decompositions.py#L133)
- [x] [Block Diagonal](https://github.com/wilson-labs/cola/blob/main/cola/linalg/decompositions/decompositions.py#L139)
- [ ] [inv(A)](https://github.com/wilson-labs/cola/blob/main/cola/linalg/inverse/inv.py)
- [x] [Identity](https://github.com/wilson-labs/cola/blob/main/cola/linalg/eig/eigs.py#L147)
- [x] [Diagonal](https://github.com/wilson-labs/cola/blob/main/cola/linalg/inverse/inv.py#L142)
- [ ] [Triangular](https://github.com/wilson-labs/cola/blob/main/cola/linalg/inverse/inv.py#L147)
- [ ] [Permutation](https://github.com/wilson-labs/cola/blob/main/cola/linalg/inverse/inv.py#L121)
- [ ] Convolution
- [ ] Sum
- [ ] [Product](https://github.com/wilson-labs/cola/blob/main/cola/linalg/inverse/inv.py#L126)
- [x] [Kronecker Product](https://github.com/wilson-labs/cola/blob/main/cola/linalg/inverse/inv.py#L137)
- [x] [Block Diagonal](https://github.com/wilson-labs/cola/blob/main/cola/linalg/inverse/inv.py#L132)
- [ ] Concatenation
- [ ] [eigs(A)](https://github.com/wilson-labs/cola/blob/main/cola/linalg/eig/eigs.py)
- [ ] [Identity](https://github.com/wilson-labs/cola/blob/main/cola/linalg/eig/eigs.py#L147)
- [ ] [Diagonal](https://github.com/wilson-labs/cola/blob/main/cola/linalg/eig/eigs.py#L179)
- [ ] [Triangular](https://github.com/wilson-labs/cola/blob/main/cola/linalg/eig/eigs.py#L156)
- [ ] Convolution
- [ ] Sum
- [ ] Kronecker Product
- [ ] Block Diagonal
- [ ] [diag(A)/trace(A)](https://github.com/wilson-labs/cola/blob/main/cola/linalg/trace/diag_trace.py)
- [ ] [Identity](https://github.com/wilson-labs/cola/blob/main/cola/linalg/trace/diag_trace.py#L55)
- [ ] [Diagonal](https://github.com/wilson-labs/cola/blob/main/cola/linalg/trace/diag_trace.py#L63)
- [ ] [Sum](https://github.com/wilson-labs/cola/blob/main/cola/linalg/trace/diag_trace.py#L71)
- [ ] [Scalar Product](https://github.com/wilson-labs/cola/blob/main/cola/linalg/trace/diag_trace.py#L84)
- [ ] Triangular
- [ ] Permutation
- [ ] [Kronecker Product](https://github.com/wilson-labs/cola/blob/main/cola/linalg/trace/diag_trace.py#L93)
- [ ] [Block Diagonal](https://github.com/wilson-labs/cola/blob/main/cola/linalg/trace/diag_trace.py#L77)
- [ ] Concatenation
- [ ] [logdet(A)](https://github.com/wilson-labs/cola/blob/main/cola/linalg/logdet/logdet.py)
- [x] [Identity](https://github.com/wilson-labs/cola/blob/main/cola/linalg/logdet/logdet.py#L118)
- [x] [Diagonal](https://github.com/wilson-labs/cola/blob/main/cola/linalg/logdet/logdet.py#L133)
- [ ] [Triangular](https://github.com/wilson-labs/cola/blob/main/cola/linalg/logdet/logdet.py#L161)
- [ ] [Permutation](https://github.com/wilson-labs/cola/blob/main/cola/linalg/logdet/logdet.py#L170)
- [ ] [Scalar Product](https://github.com/wilson-labs/cola/blob/main/cola/linalg/logdet/logdet.py#L125)
- [ ] [Kronecker Product](https://github.com/wilson-labs/cola/blob/main/cola/linalg/logdet/logdet.py#L141)
- [ ] [Block Diagonal](https://github.com/wilson-labs/cola/blob/main/cola/linalg/logdet/logdet.py#L152)
- [ ] [Unary Operations](https://github.com/wilson-labs/cola/blob/main/cola/linalg/unary/unary.py) (such as log, pow, exp, sqrt)
In addition, potential rewrites could also be applied to Toeplitz and circulant matrices, although these are not covered by COLA.
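As one concrete example of the kind of shortcut in the table above, here is a numpy sketch of the Kronecker-product rule for `logdet`, which never forms the large product matrix (shapes are arbitrary placeholders):

```python
import numpy as np

rng = np.random.default_rng()
A = rng.normal(size=(5, 5)) + 5 * np.eye(5)   # n x n
B = rng.normal(size=(7, 7)) + 7 * np.eye(7)   # m x m

# log|det(A ⊗ B)| = m * log|det(A)| + n * log|det(B)|
direct = np.linalg.slogdet(np.kron(A, B))[1]
shortcut = 7 * np.linalg.slogdet(A)[1] + 5 * np.linalg.slogdet(B)[1]

assert np.allclose(direct, shortcut)
```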
The wrinkle to all this is that we would need more information about matrices as they enter and exit `Op`s. Right now, we're using tags to accomplish rewrites like this; see for example #303 and #459. Some of these rewrites might be possible to do via inference. For example, a `pt.linalg.block_diag` `Op` always returns a block diagonal matrix, as does `pt.kron(pt.eye(n), A)`. `pt.diag` of a vector always returns a diagonal matrix, as does `pt.eye`; `pt.linalg.cholesky` always returns a triangular matrix, etc. Other potential types, like `block`, `positive`, `definite`, `psd`, `toeplitz`, `circulant`, etc., would be less trivial to automatically detect.
The other issue is that as these type tags proliferate, we become more and more locked into a somewhat hack-y system for marking things. Perhaps putting some thought into how to handle this now will save some refactoring headaches down the road?
Issue opened 20 Oct 2024 (labels: graph rewriting, performance, linalg):
### Description
Given a block diagonal matrix formed from matrices $A \in \mathbb R^{n \times m}$ and $B \in \mathbb R^{o \times p}$, such that:
$$ D = \begin{bmatrix}
A & 0 \\
0 & B
\end{bmatrix} $$
Computing the matrix multiplication $DC$ can be simplified. Define $C_1 \in \mathbb R^{m \times k}$ and $C_2 \in \mathbb R^{p \times k}$ such that:
$$
C = \begin{bmatrix} C_1 \\
C_2 \end{bmatrix}
$$
then:
$$
\begin{align}
DC &=
\begin{bmatrix}
A & 0 \\
0 & B
\end{bmatrix}
\begin{bmatrix}
C_1 \\
C_2
\end{bmatrix} \\
&= \begin{bmatrix}
A C_1 \\
B C_2 \end{bmatrix}
\end{align}
$$
We can compute these smaller dot products then concatenate the results back together for a speedup. Code:
```python
import numpy as np
from scipy import linalg
rng = np.random.default_rng()
n = 1000
A, B = rng.normal(size=(2, n, n))
C = rng.normal(size=(2*n, 2*n))
A, B, C = map(np.ascontiguousarray, [A, B, C])
def direct(A, B, C):
    X = linalg.block_diag(A, B)
    return X @ C

def rewrite(A, B, C):
    n = A.shape[0]
    C_1, C_2 = C[:n], C[n:]
    return np.concatenate([A @ C_1, B @ C_2])
np.allclose(direct(A, B, C), rewrite(A, B, C)) # True
```
Speed test:
```python
direct_time = %timeit -o direct(A, B, C)
rewrite_time = %timeit -o rewrite(A, B, C)
speedup_factor = (1 - rewrite_time.best / direct_time.best)
print(f'{speedup_factor:0.2%}')
75.8 ms ± 3.05 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
43.5 ms ± 1.62 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
40.77%
```
Issue opened 20 Oct 2024 (labels: graph rewriting, performance, linalg):
### Description
A matrix product involving a Kronecker product, $(A \otimes B)\,C$, can be rewritten according to the following relationship:
$$
(A \otimes B) C = \text{vec}(B X A^T)
$$
where $\otimes$ is the Kronecker product, and the $\text{vec}$ operation ravels a matrix in column-major order. $X$ is the matrix formed by reshaping $C$ (in column-major order) to conform with $B$. This avoids working with the large Kronecker product matrix, and instead gets the result in terms of the much smaller components. Code example:
```python
import numpy as np

n = 100
a, b = np.random.normal(size=(2, n, n))
c = np.random.normal(size=(n ** 2, ))

def kronAB_C_clever(a, b, c):
    return (b @ c.reshape((n, n)).T @ a.T).T.ravel()

def direct(a, b, c):
    K = np.kron(a, b)
    return K @ c

%timeit kronAB_C_clever(a, b, c)
73.5 μs ± 3.61 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
%timeit direct(a, b, c)
245 ms ± 72.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
```
This trick is already used in PyMC [here](https://github.com/pymc-devs/pymc/blob/5352798ee0d36ed566e651466e54634b1b9a06c8/pymc/math.py#L221), but only in a limited context. PyMC applies this identity to `solve_triangular` as well, but it can (and should) also be applied to other types of solve.
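For the solve case, the same vec trick applies, with the inverse pushed onto the small factors. A minimal numpy sketch, assuming square `a` and `b` and arbitrary placeholder sizes:

```python
import numpy as np

rng = np.random.default_rng()
n = 50
a = rng.normal(size=(n, n)) + n * np.eye(n)   # keep both well-conditioned
b = rng.normal(size=(n, n)) + n * np.eye(n)
c = rng.normal(size=(n ** 2,))

# Direct: solve against the full n^2 x n^2 Kronecker product
direct = np.linalg.solve(np.kron(a, b), c)

# Rewrite: (A ⊗ B)^{-1} vec(X) = vec(B^{-1} X A^{-T})
X = c.reshape((n, n), order="F")              # column-major un-vec
rewritten = np.linalg.solve(a, np.linalg.solve(b, X).T).T.ravel(order="F")

assert np.allclose(direct, rewritten)
```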
Issue opened 05 Aug 2024 (labels: enhancement, feature request, NumPy compatibility, linalg):
### Description
Right now we have a function `pt.nlinalg.matrix_dot` that is just a helper for doing repeated dot products. This is similar to, but worse than, `np.linalg.multi_dot`, because `multi_dot` also computes an optimal contraction path and does the dot products in a smart order.
#722 added optimized contraction path logic, so we could incorporate this into our `matrix_dot` function. While we're at it, we should rename the function to `multi_dot` to match the numpy API.
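A small numpy illustration of why the contraction order matters (shapes are arbitrary; `multi_dot` picks the cheap order automatically):

```python
import numpy as np

rng = np.random.default_rng()
A = rng.normal(size=(2_000, 100))
B = rng.normal(size=(100, 2_000))
C = rng.normal(size=(2_000, 5))

slow = (A @ B) @ C   # left-to-right builds a 2000 x 2000 intermediate
fast = A @ (B @ C)   # the optimal order only ever builds small intermediates
out = np.linalg.multi_dot([A, B, C])  # chooses the order for you

assert np.allclose(slow, fast)
assert np.allclose(out, fast)
```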
Issue opened 22 Nov 2024 (labels: beginner friendly, feature request, gradients, linalg):
### Description
QR is one of the few remaining linalg ops that is missing a gradient. JAX code for the jvp is [here](https://github.com/jax-ml/jax/blob/73fa0f48cb0081fc69cb82dffe6ceb5433cdc446/jax/_src/lax/linalg.py#L1954), which also includes [this derivation](https://j-towns.github.io/papers/qr-derivative.pdf). [This paper](https://www.tandfonline.com/doi/abs/10.1080/10556788.2011.610454) also claims to derive the gradients for QR, but I find it unreadable.
Relatedly, but perhaps worthy of a separate issue, [this paper](https://arxiv.org/abs/1710.08717) derives gradients for the LQ decomposition, $$A = LQ$$, where $L$ is lower triangular and $Q$ is orthonormal ($$Q^TQ=I$$). Compare this to QR, which gives you $$A = QR$$, where $$Q$$ is again orthonormal but $$R$$ is upper triangular, and you see why I mention it in this issue. It wouldn't be hard to offer LQ as well.
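For context on why LQ comes almost for free once QR is available: LQ of $A$ can be obtained from QR of $A^T$, so the forward computation (and, by the same transposition, the gradient machinery) is largely shared. A small numpy sketch:

```python
import numpy as np

rng = np.random.default_rng()
A = rng.normal(size=(4, 6))

# LQ of A from QR of A^T:  A^T = Q R  =>  A = R^T Q^T = L Q_lq
Q, R = np.linalg.qr(A.T)
L, Q_lq = R.T, Q.T

assert np.allclose(L, np.tril(L))             # L is lower triangular
assert np.allclose(Q_lq @ Q_lq.T, np.eye(4))  # rows of Q_lq are orthonormal
assert np.allclose(L @ Q_lq, A)
```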
Issue opened 22 Nov 2024 (labels: enhancement, help wanted, feature request, NumPy compatibility, linalg):
### Description
Numpy has `np.block`, which gives a nice shorthand for repeated concatenations. For example, say I want to make a block lower-triangular matrix:
$$
D = \begin{bmatrix} A & 0 \\
B & C \end{bmatrix}
$$
I can do:
```py
import numpy as np
A, B, C = np.ones((3, 4, 4))
B *= 2
C *= 3
zeros = np.zeros((4, 4))
D = np.block([[A, zeros], [B, C]])
print(D)
# Result:
[[1. 1. 1. 1. 0. 0. 0. 0.]
[1. 1. 1. 1. 0. 0. 0. 0.]
[1. 1. 1. 1. 0. 0. 0. 0.]
[1. 1. 1. 1. 0. 0. 0. 0.]
[2. 2. 2. 2. 3. 3. 3. 3.]
[2. 2. 2. 2. 3. 3. 3. 3.]
[2. 2. 2. 2. 3. 3. 3. 3.]
[2. 2. 2. 2. 3. 3. 3. 3.]]
```
Obviously you can do this in a million different ways, with `np.hstack`, `np.vstack`, `np.c_`, `np.r_`, `np.concatenate`, `np.stack`, etc., etc. But this one is concise and readable.
In the context of pytensor, it's related to rewrites in the same vein as #1044. We could very easily break apart these block matrices and do the matmul in chunks, especially if we see there are zero matrices among the blocks.
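For example, a product against the block lower-triangular $D$ above never needs to touch the zero block. A hedged numpy sketch of the chunked computation:

```python
import numpy as np

rng = np.random.default_rng()
A, B, C = rng.normal(size=(3, 4, 4))
x1, x2 = rng.normal(size=(2, 4))

D = np.block([[A, np.zeros((4, 4))], [B, C]])
x = np.concatenate([x1, x2])

direct = D @ x
chunked = np.concatenate([A @ x1, B @ x1 + C @ x2])  # zero block skipped entirely

assert np.allclose(direct, chunked)
```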
Issue opened 19 Jan 2024:
### Description
`CholeskySolve` currently raises a not-implemented error in `L_op`, but it could probably just use the generic solve gradient defined in the `SolveBase` class from which it inherits.
In addition, we should investigate whether it is faster to use `CholeskySolve` in the `generic_solve_to_solve_triangular` rewrite. Currently we look for a lower-triangular tag and rewrite solve to `solve_triangular`. Instead, we could look for `Solve(Cholesky(A), b)` and rewrite the whole thing to `CholeskySolve(A, b)`.
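Not the pytensor rewrite itself, but a scipy-level sketch of the underlying speed question: for a symmetric positive-definite system, solving through a (reusable) Cholesky factor versus calling the generic LU-based solve:

```python
import numpy as np
from scipy import linalg

rng = np.random.default_rng()
n = 500
A = rng.normal(size=(n, n))
A = A @ A.T + n * np.eye(n)                # symmetric positive definite
b = rng.normal(size=n)

x_generic = linalg.solve(A, b)             # generic solve (LU under the hood)

c_and_lower = linalg.cho_factor(A)         # factor once...
x_chol = linalg.cho_solve(c_and_lower, b)  # ...then solve cheaply

assert np.allclose(x_generic, x_chol)
```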
Issue opened 10 Jan 2024 (labels: feature request, backend compatibility, graph rewriting, performance, SciPy compatibility, linalg, vectorization):
### Description
[This blog post](https://www.johndcook.com/blog/2010/01/19/dont-invert-that-matrix/) makes an important point about speed in cases where we want to repeatedly compute $A^{-1}b$: we should factorize the $A$ matrix once, then recycle the factorization for repeated solves. The specific factorization used in this case is LU factorization; this can be verified by looking at the [function calls for dgesv](https://netlib.org/lapack/explore-html/d8/da6/group__gesv_ga831ce6a40e7fd16295752d18aed2d541.html): first dgetrf is called to perform an LU factorization, then dtrsm to actually solve the system. In scipy, these correspond to `linalg.lu_factor` and `linalg.lu_solve`. [Here's a gist](https://gist.github.com/jessegrabowski/2ed3b1fde4729394f2cb401f8a76bc21) showing a very simple speed test, suggesting nice speedups from computing the LU decomposition once and then recycling it many times for different $b$.
(Question: are other factorization schemes used in other solve cases (symmetric, positive-definite)?)
All of this is relevant because of blockwise. This would be very low-hanging fruit for a rewrite in cases where a user has batch dimensions on the `b` matrix.
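A scipy sketch of the factor-once, solve-many pattern described in the post, with a batch of right-hand sides standing in for the blockwise case:

```python
import numpy as np
from scipy import linalg

rng = np.random.default_rng()
n, n_batch = 500, 200
A = rng.normal(size=(n, n)) + n * np.eye(n)
bs = rng.normal(size=(n_batch, n))

# Naive: re-factorize A for every right-hand side
naive = np.stack([linalg.solve(A, b) for b in bs])

# Better: LU-factorize once, then recycle the factors
lu_and_piv = linalg.lu_factor(A)
recycled = np.stack([linalg.lu_solve(lu_and_piv, b) for b in bs])

assert np.allclose(naive, recycled)
```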
In addition, gradient implementations exist for both in pytorch ([lu_factor](https://github.com/pytorch/pytorch/blob/9eb842cbd6f80b3d8c3e88bf70e389465c86eb6f/torch/csrc/autograd/FunctionsManual.cpp#L6413), [lu_solve](https://github.com/pytorch/pytorch/blob/9eb842cbd6f80b3d8c3e88bf70e389465c86eb6f/torch/csrc/autograd/FunctionsManual.cpp#L5658)) and jax ([lu_factor](https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.linalg.lu_factor.html), [lu_solve](https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.linalg.lu_solve.html)). I haven't found an academic reference that shows where their equations come from, though.
#1100 would unlock more rewrites for matrix multiplication, but it isn't necessarily an optimization itself.
There is also room for gains by working with square-root filters. This is especially nice since we sample covariance priors in Cholesky form anyway, so we'd actually never need to instantiate the covariance matrix. I already wrote the filters, but actually using them is blocked by pytensor issue #1099, so we can't compute their gradients.
Aside from all that, there are also Chandrasekhar recursions, which can be used instead of Kalman filtering in cases where they are permitted, which I think covers the majority of interesting cases.
Finally, I think getting things over to numba is promising as well. There are some issues with the gradients currently being generated that prevent that, but I think it's a better long-term solution than JAX, which is much more difficult to extend with specialized Ops that could be used to speed up computation, like `solve_discrete_are` and `solve_discrete_lyapunov`. I made an issue on the jax repo about these Ops, but it was not warmly received. On the other hand, we already have numba dispatches for them.