My pytensor code causes an error when compiling with mode=JAX.
The workaround was to provide a single "dummy" output variable, as follows:

```python
fn = pytensor.function(inputs=[...], outputs=[dummy])
```
Using `outputs=[]` causes the error. A function can legitimately have no outputs when `updates` is used to modify shared variables; in that case the outputs are not needed. The error is easy to reproduce: just compile any function with an empty `outputs` list.
This occurs only with `mode="JAX"`.
The error ends with:

```
  File "/home/paul.baggenstoss/miniconda3/lib/python3.9/site-packages/pytensor/link/utils.py", line 184, in streamline
    raise ValueError(
ValueError: ('Length of thunks and order must match', (1, 0))
```
Thanks,
Paul