PyTensor function needs at least one output for JAX

My PyTensor code causes an error when compiling with mode=JAX.
The workaround was to provide a single "dummy" output variable
as follows:

   fn = pytensor.function(inputs=[...], outputs=[dummy])

With outputs=[], it causes an error. A function can legitimately have no outputs when updates are used to modify shared variables; in that case the outputs are not needed. The error is easy to reproduce: just compile any function with an empty outputs list.
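
For context, here is a minimal sketch of what that workaround looks like, assuming a single shared variable `x` updated from a scalar input `y` (the variable names and the choice of dummy output are just for illustration):

```python
import pytensor
import pytensor.tensor as pt

x = pytensor.shared(0.0, name="x")
y = pt.scalar("y")
update = x + y

# The function is only wanted for its side effect (updating x), but the
# JAX backend errors out here with outputs=[], so the updated value is
# returned as a dummy output to give the backend at least one output.
fn = pytensor.function(inputs=[y], outputs=[update], updates={x: update}, mode="JAX")
```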

This occurs only with mode=JAX

The error ends with:

File "/home/paul.baggenstoss/miniconda3/lib/python3.9/site-packages/pytensor/link/utils.py", line 184, in streamline
    raise ValueError(
ValueError: ('Length of thunks and order must match', (1, 0))
```
Thanks,
Paul

By the way, this is with PyTensor 2.9.1

This works on my side:

```python
import pytensor
import pytensor.tensor as pt

x = pytensor.shared(0.0, name="x")
y = pt.scalar("y")
fn = pytensor.function([y], [], updates={x: x+y}, mode="JAX")
fn(0), fn(2)
x.get_value()  # 2
```

And so does this:

```python
x.set_value(0)
fn = pytensor.function([], [], updates={x: x+1}, mode="JAX")
fn(), fn()
x.get_value()  # 2
```

Can you share the minimal code that fails on your side?

Updating PyTensor solved it, thanks
