As I understand it, PyTensor is a fork of Aesara, which is a fork of Theano. I have written scripts that convert Theano code to either Aesara or PyTensor. For the same 2D convolution problem (same code), I get the following execution times: Theano-GPU 0.7 sec, Aesara-CPU 3.4 sec, PyTensor-CPU 37 sec. This is practically unusable for me. So as Theano gets further forked, processing time gets further f**ked. I understand the technical reasons: Aesara dropped the gpuarray backend, and PyTensor seems to have dropped the GEMM/BLAS optimizations. In fact, I dug deep and it seems PyTensor performs the convolution in abstract_conv.py, which is just a bunch of nested Python loops. No wonder it is so slow. So my question is: is there any plan to add optimizations to the Conv Ops?
PyTensor didn’t drop BLAS/GEMM support, and the performance gap you see relative to Aesara is very unexpected. Can you confirm that PyTensor can find the BLAS bindings? This often fails when installing via pip instead of conda-forge.
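A quick way to check is to inspect PyTensor's configuration. The sketch below assumes the config attributes `pytensor.config.cxx` (the C++ compiler PyTensor found) and `pytensor.config.blas__ldflags` (the linker flags it uses for BLAS), and flags an empty value for either; the helper name `diagnose_pytensor_config` is hypothetical. Roughly, an empty `cxx` means PyTensor falls back to Python implementations, and an empty `blas__ldflags` means it found no BLAS library to link its C code against.

```python
def diagnose_pytensor_config(cxx: str, blas_ldflags: str) -> list[str]:
    """Hypothetical helper: return a warning for each empty compiler/BLAS setting."""
    warnings = []
    if not cxx.strip():
        warnings.append(
            "config.cxx is empty: no C++ compiler was found, so PyTensor "
            "falls back to (much slower) Python implementations."
        )
    if not blas_ldflags.strip():
        warnings.append(
            "config.blas__ldflags is empty: PyTensor found no BLAS library to "
            "link its C code against, so GEMM-based Ops may run slower."
        )
    return warnings


def main() -> None:
    try:
        import pytensor  # only needed when run as a script
    except ImportError:
        print("PyTensor is not installed in this environment.")
        return
    cfg = pytensor.config
    print("cxx:", repr(cfg.cxx))
    print("blas__ldflags:", repr(cfg.blas__ldflags))
    for message in diagnose_pytensor_config(cfg.cxx, cfg.blas__ldflags):
        print("WARNING:", message)


if __name__ == "__main__":
    main()
```

Running this once in the pip environment and once in the conda-forge environment should show whether the BLAS flags differ.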
If you can share a minimal example, someone can try to replicate it. If performance deteriorated that much, we would consider it a bug and try to fix it.
Regarding GPU, that was deemed too challenging to support properly with our limited resources. However, PyTensor functions can easily be transpiled into JAX, which has great native support for GPUs and TPUs.
Many thanks for the info and suggestions. That gives me hope, and I’ll track down the problem with BLAS/GEMM. I did use mode=JAX, and I do have jax[gpu] correctly installed, but it complained that JAX-ified Ops are not available for many of the Ops I use. I’ll report back. Thanks again!
Feel free to open an issue on our repository (pymc-devs/pytensor on GitHub) listing any missing JAX Ops that you need.
You may even choose to contribute the functionality yourself. Here is a recent example that shows how easy it can be: pymc-devs/pytensor Pull Request #294, "Add jax implementation of `pt.linalg.pinv`" by jessegrabowski.
So, I re-installed PyTensor from conda-forge (conda -c conda-forge), and yes, I now see that the C compiler and all dependencies are properly installed, including (I think) GEMM/BLAS. I have also seen the execution time speed up for the case without 2D convolution. But for 2D convolution, there is still a huge difference. I notice that Aesara has aesara/tensor/nnet/conv.py, which has an optimized convolution drawn from
`from scipy.signal._sigtools import _convolve2d`
and implemented as an OpenMP Op.
But in PyTensor, this appears to have been removed. The convolution seems to be done by brute force in abstract_conv.py.
Am I missing something here?
Thanks in advance
Yes, you’re correct. Even in Aesara it was marked as deprecated: see aesara/__init__.py at commit 4b266671f5b60ba2cf9e435248febfa07f632824 in aesara-devs/aesara on GitHub.
In general, we are trying to move away from the old C code to minimize the maintenance burden, and are betting on the Numba/JAX backends instead.
But we did want to keep the Conv Ops around because they are still so useful. I think there was an error in pymc-devs/pytensor Pull Request #111, "Remove deprecated modules" by ferrine, where we kept abstract_conv but not the conv module and the rewrites that replace the former (slow) with the latter (fast). I will open an issue on our repo to bring them back. Thanks for flagging it!
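To illustrate why the slow and fast paths differ so much, here is a minimal NumPy sketch (not PyTensor's code) of the same "valid" 2D cross-correlation computed two ways: with nested Python loops over batch, output channel and output pixel, and as a single contraction in which `sliding_window_view` exposes the image patches (im2col) and `np.tensordot` hands a matrix multiply to the BLAS GEMM routine NumPy is linked against. The function names `conv2d_loops` and `conv2d_gemm` and the (batch, channels, rows, cols) layout are assumptions for this example. Both compute cross-correlation; flipping the kernels with `w[:, :, ::-1, ::-1]` gives true convolution, which, as far as I know, is what a Theano-style `conv2d` computes with its default `filter_flip=True`.

```python
import time

import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def conv2d_loops(x, w):
    """'Valid' 2D cross-correlation written as nested Python loops.

    x: (batch, in_channels, rows, cols); w: (out_channels, in_channels, kh, kw).
    """
    n, c, rows, cols = x.shape
    f, c_w, kh, kw = w.shape
    assert c == c_w, "input and filter channel counts must match"
    oh, ow = rows - kh + 1, cols - kw + 1
    out = np.zeros((n, f, oh, ow), dtype=np.result_type(x, w))
    for b in range(n):
        for o in range(f):
            for i in range(oh):
                for j in range(ow):
                    # One patch times one filter, summed over (channels, kh, kw).
                    out[b, o, i, j] = np.sum(x[b, :, i:i + kh, j:j + kw] * w[o])
    return out


def conv2d_gemm(x, w):
    """Same result as conv2d_loops, computed as one im2col + matrix multiply."""
    kh, kw = w.shape[2:]
    # Read-only view of every kh x kw patch: (batch, in_channels, oh, ow, kh, kw).
    patches = sliding_window_view(x, (kh, kw), axis=(2, 3))
    # tensordot copies the patches into a 2D (im2col) matrix and calls np.dot,
    # which NumPy normally hands to its BLAS matrix-multiply (GEMM) routine.
    out = np.tensordot(patches, w, axes=([1, 4, 5], [1, 2, 3]))  # (batch, oh, ow, out_ch)
    return out.transpose(0, 3, 1, 2)  # back to (batch, out_channels, oh, ow)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    x = rng.standard_normal((4, 3, 32, 32))
    w = rng.standard_normal((8, 3, 5, 5))
    results = {}
    for fn in (conv2d_loops, conv2d_gemm):
        start = time.perf_counter()
        results[fn.__name__] = fn(x, w)
        print(f"{fn.__name__}: {time.perf_counter() - start:.4f} s")
    assert np.allclose(results["conv2d_loops"], results["conv2d_gemm"])
```

The exact speed ratio depends on the shapes and on the BLAS library, but the loop version pays Python interpreter overhead once per output element, while the contraction does all of the arithmetic inside compiled code.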
Thanks, that would be helpful.
Also, note that there is a bug in the existing abstract_conv.py: I had to add a cast at line 682:
`filters = pytensor.tensor.cast(filters, dtype=input.dtype)`
Without this, abstract_conv upcasts the data to float64, even when the input and output are meant to be float32.
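The upcast follows the usual dtype promotion rule: combining a float32 tensor with a float64 one yields float64. NumPy applies the same rule, so the effect can be sketched without PyTensor. Below, `correlate2d_valid` and `correlate2d_keep_dtype` are hypothetical single-channel helpers, and the float64 kernel stands in for filters that arrive as float64 (for example, a NumPy array created with the default dtype). Casting the kernel to the input dtype first, as the one-line fix does with `pytensor.tensor.cast`, keeps the result in float32.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def correlate2d_valid(x, w):
    """Hypothetical helper: 'valid' 2D cross-correlation of one image with one kernel."""
    patches = sliding_window_view(x, w.shape)  # (oh, ow, kh, kw)
    return np.einsum("ijkl,kl->ij", patches, w)


def correlate2d_keep_dtype(x, w):
    """Cast the kernel to the input dtype first, the same idea as the one-line fix."""
    return correlate2d_valid(x, w.astype(x.dtype, copy=False))


x = np.arange(16, dtype=np.float32).reshape(4, 4)
w = np.ones((2, 2))  # NumPy defaults to float64

print(correlate2d_valid(x, w).dtype)       # float64: the float32 input was upcast
print(correlate2d_keep_dtype(x, w).dtype)  # float32
```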
Do you want to open a pull request to fix that in the source code?
Hi, so far I have never done a pull request, and I’m not sure how. I’ll read up on that, or I don’t mind if you add that one line for me #:^)
Sure. In case you’re interested, we have a tutorial for PyMC that should be similar for PyTensor: "Pull request step-by-step" in the PyMC 5.3.1 documentation.
OK, I’ll read up, but I’m not sure when I will have time. Not sure I understood: could you make that one-line change for me? I’d appreciate that.
If I find the time.