Very slow 2D convolution with PyTensor

Feel free to open an issue on our repository (pymc-devs/pytensor on GitHub) listing any missing JAX Ops that you need.

You may even choose to contribute the functionality yourself. Here is a recent example that shows how easy it can be: Pull Request #294 in pymc-devs/pytensor, "Add jax implementation of `pt.linalg.pinv`" by jessegrabowski.
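To give a rough idea of why such a contribution tends to be small: as far as I know, the JAX backend finds implementations through a single-dispatch registry keyed on the Op class (`jax_funcify` in `pytensor.link.jax.dispatch`), so adding an Op mostly means registering one function. The sketch below imitates that pattern with the standard library only. `Op`, `Conv1d`, and `backend_funcify` are hypothetical stand-ins for illustration, not PyTensor's actual classes or API.

```python
from functools import singledispatch


class Op:
    """Hypothetical stand-in for a symbolic graph operation."""


class Conv1d(Op):
    """Toy op: 'valid' 1D convolution with a fixed kernel (illustration only)."""

    def __init__(self, kernel):
        self.kernel = list(kernel)


@singledispatch
def backend_funcify(op):
    # Fallback for ops nobody has registered yet; this is the situation
    # a "missing Op" report describes.
    raise NotImplementedError(
        f"No backend implementation for {type(op).__name__}"
    )


@backend_funcify.register(Conv1d)
def _funcify_conv1d(op):
    # Convolution flips the kernel before sliding it over the input.
    flipped = op.kernel[::-1]
    k = len(flipped)

    def conv1d(x):
        # One output per position where the kernel fully overlaps x.
        return [
            sum(x[i + j] * flipped[j] for j in range(k))
            for i in range(len(x) - k + 1)
        ]

    return conv1d


# Look up the implementation by the op's type, then call it on data.
fn = backend_funcify(Conv1d([1, 0, -1]))
result = fn([1, 2, 4, 7, 11])
```

In a real contribution the registered function would return a callable built from `jax.numpy` operations rather than plain Python lists, and the PR linked above is the better reference for the exact shape.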