The error does not occur with mode=FAST_RUN.
Yes, maybe there is another bug that leads to that. Let me work on it; maybe I can find out when the ndarray gets into the graph.
Are you using a custom Op somewhere?
I was able to get a small program that causes the error:
#---------------- START -------------
import os
import scipy.io as scio
import numpy as np
import pytensor
import pytensor.tensor as T
from defaults import cfg, data, BACKEND
from pytensor.ifelse import ifelse
from pytensor.tensor.random.utils import RandomStream as RandomStreams
import collections
def create_pp_fn():
    # create graph for preprocessing fn
    inp=T.tensor4()
    inp2 = inp
    # add noise for regularization
    inp2 = inp2 + cfg.pytensor_rng.normal(size = inp2.shape, scale = 1, dtype = pytensor.config.floatX)
    pupdates = collections.OrderedDict()
    pupdates[cfg.input_pp] = inp2.reshape(cfg.input_pp.shape)
    pp_fn = pytensor.function(inputs=[], \
        updates=pupdates, outputs=[],\
        givens={inp: cfg.input_pp},on_unused_input='ignore',allow_input_downcast=True)
    return pp_fn
cfg.seed=15
cfg.batchsize=25
cfg.numpy_rng = np.random.RandomState(cfg.seed)
cfg.pytensor_rng = RandomStreams(cfg.numpy_rng.randint(2**30))
cfg.input_chan=1
cfg.nonlin_in=0
cfg.ncol_in=28
cfg.nrow_in=28
batchshape = (cfg.batchsize,1,cfg.nrow_in, cfg.ncol_in)
initial_pp = np.zeros(batchshape, dtype=pytensor.config.floatX)
cfg.input_pp=pytensor.shared(initial_pp)
cfg.pp_fn = create_pp_fn()
#---------------- END -------------
Using this example, the "fix" that I suggested above works to avoid the error.
I need to remove the reference to "import defaults". I'll send a corrected file.
OK, this now stands alone…
import os
import scipy.io as scio
import numpy as np
import pytensor
import pytensor.tensor as T
from pytensor.ifelse import ifelse
from pytensor.tensor.random.utils import RandomStream as RandomStreams
import collections
class Object:
    pass
cfg = Object()
def create_pp_fn():
    # create graph for preprocessing fn
    inp=T.tensor4()
    inp2 = inp
    # add noise for regularization
    inp2 = inp2 + cfg.pytensor_rng.normal(size = inp2.shape, scale = 1, dtype = pytensor.config.floatX)
    pupdates = collections.OrderedDict()
    pupdates[cfg.input_pp] = inp2.reshape(cfg.input_pp.shape)
    pp_fn = pytensor.function(inputs=[], \
        updates=pupdates, outputs=[],\
        givens={inp: cfg.input_pp},on_unused_input='ignore',allow_input_downcast=True)
    return pp_fn
cfg.seed=15
cfg.batchsize=25
cfg.numpy_rng = np.random.RandomState(cfg.seed)
cfg.pytensor_rng = RandomStreams(cfg.numpy_rng.randint(2**30))
cfg.input_chan=1
cfg.nonlin_in=0
cfg.ncol_in=28
cfg.nrow_in=28
batchshape = (cfg.batchsize,1,cfg.nrow_in, cfg.ncol_in)
initial_pp = np.zeros(batchshape, dtype=pytensor.config.floatX)
cfg.input_pp=pytensor.shared(initial_pp)
cfg.pp_fn = create_pp_fn()
Here is a simplified case that reproduces the problem:
import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.tensor.random.utils import RandomStream
pytensor_rng = RandomStream(1)
batchshape = (25 ,1, 28, 28)
inp_shared = pytensor.shared(np.zeros(batchshape), name="inp_shared")
inp = pt.tensor4(name="inp")
out = inp + pytensor_rng.normal(size=inp.shape, scale=1)
updates = {inp_shared: out.reshape(inp_shared.shape)}
pp_fn = pytensor.function(
    inputs=[],
    outputs=[],
    updates=updates,
    givens={inp: inp_shared},
    mode="JAX",
)
I opened an issue here: Function with only updates after givens fails in JAX mode · Issue #314 · pymc-devs/pytensor · GitHub
Thanks for your investigative work! I opened a PR with a fix here: Fix bug in JAX cloning of RNG shared variables by ricardoV94 · Pull Request #315 · pymc-devs/pytensor · GitHub
Thanks, very much appreciated. Once this and the dynamic indexing bug are fixed, I can port my software from Theano to PyTensor.
The dynamic slice is more of a new feature than a bugfix. It might take some time for someone to look at it. If you are interested, you could try to add that functionality yourself.
Totally understandable if you don't want to or can't, of course.
Thanks. Yes, I'm not in a position to do that, as I'm more of a user than a developer, but I can see that the error comes with a built-in helpful suggestion:
"To index a statically sized array at a dynamic position, try lax.dynamic_slice/dynamic_update_slice
(JAX does not support dynamically sized arrays within JIT compiled functions)."
So maybe the new feature is relatively easy.
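For reference, here is roughly what that suggestion means in plain JAX (a minimal sketch, not PyTensor code). The start position may be a traced, runtime value, but the slice size must be a static Python int so that every output shape is known when the function is compiled. `WINDOW`, `take_window`, and `write_window` are hypothetical names chosen for illustration. Note that `lax.dynamic_slice` clamps the start index so the window stays inside the array instead of raising an error.

```python
import jax
import jax.numpy as jnp
from jax import lax

# Static window length: a plain Python int, fixed at trace time.
WINDOW = 3


@jax.jit
def take_window(x, start):
    # `start` is traced (a runtime value); only the slice size must be static.
    # If start + WINDOW would run past the end, JAX clamps `start` instead.
    return lax.dynamic_slice(x, (start,), (WINDOW,))


@jax.jit
def write_window(x, update, start):
    # Overwrite WINDOW elements of x beginning at `start`; x keeps its shape.
    return lax.dynamic_update_slice(x, update, (start,))


x = jnp.arange(10.0)
print(take_window(x, 4))                        # elements 4.0, 5.0, 6.0
print(write_window(x, jnp.zeros(WINDOW), 2))    # positions 2..4 set to 0.0
```

The shape of every result depends only on `WINDOW` and `x.shape`, never on `start`, which is what lets `jit` compile these functions.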
That error message is more helpful for users writing simple JAX functions than for a wrapper library like PyTensor which is trying to convert a graph into a JAX function.
For me, the advice is still some steps away from being actionable. I don't even know if it's a limitation we can actually overcome without experimenting with it.
Thanks. What would you suggest as the easiest solution for me? I suppose I could develop a new custom function that copies data with dynamic indexing into a fixed-size array. That might be easier than adding the general feature to PyTensor.
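One way to read the "copy into a fixed-size array" idea, sketched in plain JAX under the assumption that an upper bound `max_len` on the slice length is known ahead of time: read a window of static size `max_len` starting at a dynamic position, zero out the entries beyond the dynamic `length`, and return a mask marking the valid entries. The helper `copy_into_fixed` and its arguments are hypothetical, it handles only 1-D inputs, and wrapping it as a PyTensor Op for the JAX backend is not shown.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax import lax


@partial(jax.jit, static_argnames="max_len")
def copy_into_fixed(src, start, length, max_len):
    """Hypothetical helper: copy src[start:start + length] into a 1-D buffer of static size max_len.

    `start` and `length` may be traced (runtime) values; only `max_len` must be static.
    Assumes 0 <= start <= src.shape[0] and 0 <= length <= max_len. Under jit these
    conditions cannot raise a Python exception, so out-of-range values produce
    clamped or truncated results rather than an error.
    """
    # Pad with max_len zeros so a window of size max_len starting at any
    # start <= src.shape[0] fits without dynamic_slice clamping the start index.
    padded = jnp.pad(src, (0, max_len))
    window = lax.dynamic_slice(padded, (start,), (max_len,))
    # Positions at or beyond `length` are treated as padding and zeroed.
    mask = jnp.arange(max_len) < length
    return jnp.where(mask, window, 0), mask


src = jnp.arange(10.0)
buf, mask = copy_into_fixed(src, 7, 3, max_len=5)
```

The buffer always has shape `(max_len,)` whatever `start` and `length` are at run time; downstream code would use `mask` to ignore the padded tail.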