This works on my side:
import pytensor
import pytensor.tensor as pt
x = pytensor.shared(0.0, name="x")
y = pt.scalar("y")
fn = pytensor.function([y], [], updates={x: x+y}, mode="JAX")
fn(0), fn(2)  # x: 0.0 -> 0.0 -> 2.0
x.get_value() # 2
And so does this:
x.set_value(0.0)  # reset x before the second test
fn = pytensor.function([], [], updates={x: x+1}, mode="JAX")
fn(), fn()
x.get_value() # 2
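One way to picture what `updates` does, independent of the backend: on every call, the compiled function computes the new value of each shared variable from its old value, then writes it back into the variable's storage, so repeated calls accumulate. The toy below is a hypothetical plain-Python model of that behaviour (`ToyShared` and `toy_function` are made-up names, not PyTensor's API). It simply reproduces the two expected results above:

```python
class ToyShared:
    """Hypothetical stand-in for a shared variable: it only holds a value."""

    def __init__(self, value):
        self.value = value


def toy_function(update_rules):
    """Hypothetical stand-in for a compiled function with `updates`.

    update_rules maps each ToyShared to a callable (old_value, *inputs) -> new_value.
    All new values are computed from the old ones first, then written back.
    """

    def call(*inputs):
        new_values = {var: rule(var.value, *inputs) for var, rule in update_rules.items()}
        for var, new in new_values.items():
            var.value = new
        return []  # no outputs requested, as in the examples above

    return call


x = ToyShared(0.0)
fn = toy_function({x: lambda old, y: old + y})
fn(0), fn(2)
print(x.value)  # 2.0

x.value = 0.0
fn = toy_function({x: lambda old: old + 1})
fn(), fn()
print(x.value)  # 2.0
```

If the real JAX-mode function does not match this model in your setup (for example, `x.get_value()` stays at its initial value), that is exactly the case worth reporting.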
Could you share a minimal example that fails on your side?