Help wanted: Continue PyTorch backend for PyTensor

Anyone interested in contributing a PyTorch backend to PyTensor? I started a PR here: Experimental: PyTorch backend by twiecki (Pull Request #457, pymc-devs/pytensor).

It doesn’t work yet, but it’s a good starting point for someone motivated. The failing tests point to the current problems.

How did I create this? I copied the JAX backend and asked GPT to port the code to PyTorch.

Why not target Keras 3 directly instead of separate JAX, TF, and PyTorch backends? Bonus: Apple MLX and Mojo backends in the future.

Isn’t Keras very specific to NNs?

Not anymore. Keras 3 has an Ops API that is already pretty extensive. Whatever distributions and other functions are missing look possible to add directly to Keras, and the process doesn’t look very bureaucratic: see, e.g., Implement `binomial` and `beta` distribution functions in `keras.random` by KhawajaAbaid (Pull Request #18920, keras-team/keras). See also Probability Distributions in Keras-Core (Issue #18435, keras-team/keras).

Still seems a bit too focused on deep learning, but interesting.

From an eager standpoint, I don’t know if it would save us much work in supporting a PyTorch backend, and we already have pretty good support for JAX. We certainly wouldn’t want to have to contribute additional Ops to Keras just to be able to transpile from PyTensor to PyTorch.

Having said that, it wouldn’t be too hard to map. We could add an extension for it, perhaps as a small separate library.

What could also be nice is the other way around: translating a Keras graph into a PyTensor graph.

This would have a big upside for Apple Silicon folks, who currently don’t have GPU access, since PyTorch supports Apple Silicon GPUs. JAX supposedly does too, but I don’t see that coming to fruition for at least the next year, because Apple does not open-source their code for Jax-Metal and the JAX folks are struggling to get it to work.

If I understand correctly, all of the existing PyTensor backends work with NumPy arrays seamlessly (although in JAX, this doesn’t permit controlling the placement of these arrays in GPU memory).

It’s not the case for PyTorch, which can work with NumPy arrays only after they have been converted via torch.from_numpy().
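
A minimal sketch of that conversion step, with array contents and variable names chosen only for illustration. `torch.from_numpy()` wraps the existing NumPy buffer rather than copying it, so the tensor and the array share memory:

```python
import numpy as np
import torch

# A NumPy array, as a PyTensor constant might hold it (illustrative data)
x_np = np.arange(6, dtype=np.float64).reshape(2, 3)

# Wrap the existing buffer as a torch tensor; no copy is made
x_t = torch.from_numpy(x_np)

# Torch functions can now operate on the converted tensor
y_t = torch.exp(x_t).sum()

# Because memory is shared, an in-place change to the array is visible in the tensor
x_np[0, 0] = 10.0
shares_memory = x_t[0, 0].item() == 10.0

# Convert the result back to NumPy for callers that expect arrays
y_np = y_t.numpy()
```

If sharing memory is not wanted, `torch.tensor(x_np)` makes a copy instead.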

@twiecki @ricardoV94 Should we deal with this by creating a GraphRewriter that scans the graph for TensorConstants with NumPy data and wraps them with torch.from_numpy()?

@leventov you should check how we transpile to the JAX backend. The idea is that we write a Python function with nested functions, one per PyTensor Op, that achieves the same outcome but using JAX functions. We also convert constant/non-constant inputs to valid JAX types when needed.
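
A rough sketch of that pattern in plain Python, not the actual PyTensor code. The `Add` and `Exp` classes, `funcify`, and `compile_graph` are hypothetical stand-ins, and `math` plays the role of the target backend. As far as I know, the real JAX linker does something similar with a single-dispatch `jax_funcify` function, but treat the details here as assumptions:

```python
import math
from functools import singledispatch


# Hypothetical stand-ins for PyTensor Ops (not the real classes)
class Add:
    pass


class Exp:
    pass


@singledispatch
def funcify(op):
    """Return a backend function for `op`; fails for unregistered Op types."""
    raise NotImplementedError(f"No conversion for {type(op).__name__}")


@funcify.register(Add)
def _funcify_add(op):
    def add(x, y):
        return x + y

    return add


@funcify.register(Exp)
def _funcify_exp(op):
    def exp(x):
        # In a real backend this would call the backend's own exp function
        return math.exp(x)

    return exp


def compile_graph(nodes, input_names, output_name):
    """Build one Python function from topologically sorted (op, inputs, output) nodes."""
    thunks = [(funcify(op), ins, out) for op, ins, out in nodes]

    def compiled(*args):
        # Convert inputs to a type the "backend" accepts (here: plain floats)
        env = {name: float(arg) for name, arg in zip(input_names, args)}
        for fn, ins, out in thunks:
            env[out] = fn(*(env[name] for name in ins))
        return env[output_name]

    return compiled


# Toy graph computing exp(a + b)
graph = [(Add(), ("a", "b"), "s"), (Exp(), ("s",), "out")]
f = compile_graph(graph, ("a", "b"), "out")
result = f(1, 2)
```

A PyTorch backend could presumably follow the same shape, with the nested functions calling torch functions and the input conversion step producing torch tensors.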

Hi, I am Hung, a 3rd-year student interested in this project for GSoC 2024. I have experience with PyTorch, which was required for my deep learning engineering studies, and I recently started contributing to PyMC (my personal information is included in the proposal below).

I am writing a proposal for the project at this link: The proposal is submitted to PyMC organization under the NumFOCUS umbrella organization (Google Docs). I hope potential mentors @ricardoV94 and @twiecki can take a look at this proposal and comment on:

  1. Whether it is good enough to be accepted.
  2. If it is not good enough, how I should modify it so that it can be accepted.

For any problem (including the jokes and the tone), you can comment on the document directly. I will edit it as quickly as possible. As the deadline (2 April) is approaching, I hope it can be read promptly :smiling_face_with_tear:. Thank you in advance!

Update:

  1. I have added implementation details and will keep adding more until I finish the implementation timeline.
  2. I have submitted a working version of the proposal on the GSoC 2024 page. The program allows me to change the proposal up until 3 April even after getting accepted. I will incrementally update the submitted proposal until the final draft. It would be great if I could receive early and consistent feedback :slight_smile: :+1: !

Hi,

It seems that I cannot edit my reply anymore. I have submitted the final draft of the proposal. If potential mentor @ricardoV94 and other contributors have any feedback, I will edit the proposal and resubmit it!

@hangenyuu this looks like a great proposal, thanks for submitting!

Congrats @Harshvir_Sandhu on getting this project. I was sad not to get it myself, but the decision is up to the PyMC team, and @Harshvir_Sandhu has been contributing much more, so no complaints from me. Looking forward to using PyMC/PyTensor with PyTorch in the future. If you need any help, don’t hesitate to contact me!

The rough part of my full-time job has passed again, so I will be back contributing to PyMC this weekend.

Thank you @HangenYuu for your congratulatory message.
I would love to collaborate with you in the future and stay connected.
Wishing you the best of luck in all your endeavours.
Best Regards

I just saw this new library, Posteriors (normal-computing/posteriors on GitHub: Uncertainty quantification with PyTorch), which is based on PyTorch. Not sure if you have seen it. I am not sure if it is helpful, but I thought I’d share.

Z.
