PyTensor function needs at least one output for JAX

A function with no outputs and only an update works on my side:

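import pytensor
import pytensor.tensor as pt

# Shared variable that the function modifies through `updates`
x = pytensor.shared(0.0, name="x")
y = pt.scalar("y")
# Empty output list, one update, compiled with the JAX backend
fn = pytensor.function([y], [], updates={x: x + y}, mode="JAX")
fn(0), fn(2)
x.get_value()  # 2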

And so does this:

# Reuses `x` from the snippet above; this time there are no inputs either
x.set_value(0)
fn = pytensor.function([], [], updates={x: x + 1}, mode="JAX")
fn(), fn()
x.get_value()  # 2

Can you share the minimal code that fails on your side?