FFT-based Cholesky decomposition for GPs on a grid?

I haven’t seen it before, but there is an FFT sub-module in pytensor that can do FFT/inverse FFT. I’d be interested to see what the code would look like if you wrote up how it would work.
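For concreteness, here's a rough NumPy sketch of what I'd guess this looks like. It assumes a stationary kernel on a regular 1D grid and uses circulant embedding, the usual FFT trick here, which isn't literally a Cholesky. The Toeplitz covariance is embedded in a larger circulant matrix whose eigenvalues are the FFT of its first column. Applying a matrix square root of that circulant then costs O(m log m) instead of an O(n^3) factorization. The helper names (`exponential_kernel`, `circulant_embedding`, `sqrt_matvec`, `sample_gp`) are hypothetical, and this is plain NumPy rather than pytensor's FFT ops.

```python
import numpy as np


def exponential_kernel(r, lengthscale=1.0):
    """Matern-1/2 (exponential) covariance as a function of distance r."""
    return np.exp(-np.abs(r) / lengthscale)


def circulant_embedding(n, dx, kernel):
    """Return the first column c of the minimal circulant embedding of the
    n x n Toeplitz covariance on a regular 1D grid with spacing dx, plus the
    circulant's eigenvalues."""
    row = kernel(np.arange(n) * dx)            # k(0), k(dx), ..., k((n-1)dx)
    c = np.concatenate([row, row[-2:0:-1]])    # length m = 2(n - 1), symmetric
    lam = np.fft.fft(c).real                   # eigenvalues of the circulant
    if lam.min() < -1e-10 * lam.max():
        raise ValueError("embedding is not PSD; try padding the grid")
    return c, np.clip(lam, 0.0, None)          # drop tiny negative rounding


def sqrt_matvec(lam, z):
    """Apply the symmetric square root C^{1/2} of the circulant matrix to z
    in O(m log m); this plays the role of L in f = L @ z."""
    return np.fft.ifft(np.sqrt(lam) * np.fft.fft(z)).real


def sample_gp(n, dx, kernel, rng):
    """Draw one zero-mean GP sample on the n-point grid: the first n entries
    of C^{1/2} z have the original Toeplitz covariance."""
    c, lam = circulant_embedding(n, dx, kernel)
    z = rng.standard_normal(c.size)
    return sqrt_matvec(lam, z)[:n]
```

The square root only replaces the `L @ z` step of a non-centered GP. It doesn't directly give the Toeplitz log-determinant or a solve, and some kernels (e.g. a squared exponential with a long lengthscale) may need a padded grid before the embedding is positive semidefinite.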

It’s not an actively maintained area of the codebase, so I’m not sure whether graphs containing FFT ops will compile to the JAX/Numba backends. I’d be interested in benchmarks comparing pytensor+chol, jax+chol, pytensor+fft-chol, and maybe a JAX-native fft-chol (if one exists). There’s also a package that lets you call scipy.fft inside numba-njit functions, so that might be interesting to add to the mix too.

Anyway I’m very interested!