FFT-based Cholesky decomposition for GPs on a grid?

Does PyMC have any built-in ability, for GP models evaluated on a grid, to perform the Cholesky decomposition (usually O(n^3)) using the FFT method (which achieves O(n log n))? Or has anyone hand-implemented this?
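Strictly speaking, the FFT route does not produce a Cholesky factor. For a stationary kernel on a regular 1D grid, the n x n covariance is Toeplitz, and it can be embedded in a larger circulant matrix that the FFT diagonalizes ("circulant embedding"). That is enough to draw GP samples in O(m log m). Below is a minimal NumPy sketch under those assumptions. The names `exponential_kernel` and `circulant_embedding_sample` are hypothetical, not PyMC or PyTensor API. The exponential (Matérn-1/2) kernel is chosen because its minimal 1D embedding is known to be nonnegative definite; smoother kernels such as the squared exponential may need a padded grid.

```python
import numpy as np


def exponential_kernel(r, lengthscale=0.3, variance=1.0):
    # Hypothetical helper: Matern-1/2 (exponential) kernel as a function of distance r.
    # Its minimal 1D circulant embedding is nonnegative definite.
    return variance * np.exp(-r / lengthscale)


def circulant_embedding_sample(n, dx, kernel, rng):
    """Draw two independent zero-mean GP samples on a regular 1D grid of n points.

    The Toeplitz covariance is embedded in a symmetric circulant matrix of size
    m = 2(n - 1), whose eigenvalues come from a single FFT, so no O(n^3)
    Cholesky factorization is needed.
    """
    m = 2 * (n - 1)
    j = np.arange(m)
    # First row (= first column) of the circulant: lags 0, dx, ..., (n-1)dx, then mirrored
    c = kernel(np.minimum(j, m - j) * dx)
    # Eigenvalues of a symmetric circulant matrix: the DFT of its first row (real up to roundoff)
    lam = np.fft.fft(c).real
    if lam.min() < -1e-10 * lam.max():
        raise ValueError("circulant embedding is not nonnegative definite; pad the grid")
    lam = np.clip(lam, 0.0, None)
    # Complex white noise: real and imaginary parts of y are independent, each with covariance C
    z = rng.standard_normal(m) + 1j * rng.standard_normal(m)
    y = np.fft.fft(np.sqrt(lam / m) * z)
    # The first n entries have exactly the covariance of the original grid
    return y.real[:n], y.imag[:n]


rng = np.random.default_rng(0)
f1, f2 = circulant_embedding_sample(500, 0.01, exponential_kernel, rng)
print(f1.shape, f2.shape)  # (500,) (500,)
```

Each call yields two independent draws for the price of one complex FFT, which is the usual trick with this method.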


I haven’t seen it before, but there is an FFT sub-module in pytensor that can do FFT/inverse FFT. I’d be interested to see what the code would look like if you wrote up how it would work.
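For the likelihood itself, the cleanest case is a covariance that is exactly circulant, for example a periodic kernel whose period equals the grid length, plus i.i.d. noise. Then log|K| is the sum of the log-eigenvalues and the quadratic form is a weighted sum over FFT coefficients. Here is a minimal sketch in NumPy rather than PyTensor. Porting it would mean swapping `np.fft` for the PyTensor FFT submodule, whose exact signatures aren't shown here. The function names are hypothetical. For non-periodic kernels K is only Toeplitz, so this exact formula does not apply directly.

```python
import numpy as np
import scipy.linalg


def periodic_first_row(n, lengthscale=0.5, variance=1.0, noise=1e-2):
    # Hypothetical setup: periodic (ExpSineSquared) kernel with period equal to the
    # grid length, plus i.i.d. noise. On a regular grid this K is exactly circulant.
    j = np.arange(n)
    c = variance * np.exp(-2.0 * np.sin(np.pi * j / n) ** 2 / lengthscale**2)
    c[0] += noise
    return c


def gp_logp_fft(y, c):
    # Circulant K = F^{-1} diag(lam) F with lam = fft(c), so
    # log|K| = sum(log lam) and y^T K^{-1} y = sum(|fft(y)|^2 / lam) / n.
    n = y.size
    lam = np.fft.fft(c).real
    quad = np.sum(np.abs(np.fft.fft(y)) ** 2 / lam) / n
    return -0.5 * (quad + np.sum(np.log(lam)) + n * np.log(2.0 * np.pi))


def gp_logp_cholesky(y, c):
    # Dense O(n^3) reference: build K explicitly, factor it, one triangular solve
    K = scipy.linalg.circulant(c)
    L = np.linalg.cholesky(K)
    alpha = scipy.linalg.solve_triangular(L, y, lower=True)
    logdet_half = np.sum(np.log(np.diag(L)))  # equals 0.5 * log|K|
    return -0.5 * (alpha @ alpha) - logdet_half - 0.5 * y.size * np.log(2.0 * np.pi)


rng = np.random.default_rng(0)
c = periodic_first_row(256)
y = rng.standard_normal(256)
print(gp_logp_fft(y, c), gp_logp_cholesky(y, c))  # should agree to roundoff
```

The FFT version never forms K, so memory is O(n) as well as time O(n log n).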

It’s not an actively maintained area of the codebase, so I’m not sure whether graphs with FFT ops can be compiled to the JAX or Numba backends. I’d be interested in benchmarks comparing pytensor+chol, jax+chol, pytensor+fft-chol, and maybe a native JAX fft-chol (if one exists). There’s also a package that lets you numba-njit functions that use scipy.signal.fft, so that might be interesting to add to the mix too.
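A benchmark could start from a small harness like this NumPy/SciPy-only sketch. It times a dense Cholesky solve against an FFT solve for the same circulant covariance, under the same periodic-kernel assumption as above. JAX, PyTensor, or Numba variants would be extra entries in the hypothetical `SOLVERS` dict; they are not shown here. All function names are illustrative.

```python
import timeit

import numpy as np
import scipy.linalg


def make_problem(n, lengthscale=0.5, noise=1e-2, seed=0):
    # Hypothetical setup: periodic kernel on a regular grid, so K is circulant
    j = np.arange(n)
    c = np.exp(-2.0 * np.sin(np.pi * j / n) ** 2 / lengthscale**2)
    c[0] += noise
    y = np.random.default_rng(seed).standard_normal(n)
    return c, y


def solve_cholesky(c, y):
    # Dense route: form K, factor it (O(n^3)), then two triangular solves
    K = scipy.linalg.circulant(c)
    return scipy.linalg.cho_solve(scipy.linalg.cho_factor(K), y)


def solve_fft(c, y):
    # Circulant route: K^{-1} y = ifft(fft(y) / fft(c)), O(n log n)
    lam = np.fft.fft(c).real
    return np.fft.ifft(np.fft.fft(y) / lam).real


SOLVERS = {"numpy+chol": solve_cholesky, "numpy+fft": solve_fft}


def benchmark(sizes, repeats=5):
    # Returns {n: {solver_name: mean seconds per call}}
    results = {}
    for n in sizes:
        c, y = make_problem(n)
        results[n] = {
            name: timeit.timeit(lambda f=f: f(c, y), number=repeats) / repeats
            for name, f in SOLVERS.items()
        }
    return results


if __name__ == "__main__":
    for n, times in benchmark([256, 1024, 2048]).items():
        print(n, {k: f"{v * 1e3:.2f} ms" for k, v in times.items()})
```

Timing the solve rather than the full log-likelihood isolates the part where O(n^3) vs O(n log n) actually differs.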

Anyway I’m very interested!