The error does not occur for mode=FAST_RUN
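For context, the mode is just the backend selected when calling pytensor.function; a toy sketch of the two compilation paths (illustrative graph, not the failing one):

import pytensor
import pytensor.tensor as pt

x = pt.vector("x")
# Same graph, two backends: FAST_RUN is the default C/Python pipeline,
# while mode="JAX" lowers the graph to a jitted JAX function instead.
f_fast = pytensor.function([x], 2 * x, mode="FAST_RUN")
f_jax = pytensor.function([x], 2 * x, mode="JAX")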
Yes, maybe there is another bug that leads to that happening. Let me work on it, and maybe I can find out when the ndarray gets into the graph.
Are you using a custom Op somewhere?
I was able to get a small program that causes the error:
#---------------- START -------------
import collections

import numpy as np
import pytensor
import pytensor.tensor as T
from defaults import cfg, data, BACKEND
from pytensor.tensor.random.utils import RandomStream as RandomStreams


def create_pp_fn():
    # create graph for preprocessing fn
    inp = T.tensor4()
    inp2 = inp
    # add noise for regularization
    inp2 = inp2 + cfg.pytensor_rng.normal(size=inp2.shape, scale=1, dtype=pytensor.config.floatX)
    pupdates = collections.OrderedDict()
    pupdates[cfg.input_pp] = inp2.reshape(cfg.input_pp.shape)
    pp_fn = pytensor.function(
        inputs=[],
        updates=pupdates,
        outputs=[],
        givens={inp: cfg.input_pp},
        on_unused_input="ignore",
        allow_input_downcast=True,
    )
    return pp_fn


cfg.seed = 15
cfg.batchsize = 25
cfg.numpy_rng = np.random.RandomState(cfg.seed)
cfg.pytensor_rng = RandomStreams(cfg.numpy_rng.randint(2**30))
cfg.input_chan = 1
cfg.nonlin_in = 0
cfg.ncol_in = 28
cfg.nrow_in = 28
batchshape = (cfg.batchsize, 1, cfg.nrow_in, cfg.ncol_in)
initial_pp = np.zeros(batchshape, dtype=pytensor.config.floatX)
cfg.input_pp = pytensor.shared(initial_pp)
cfg.pp_fn = create_pp_fn()
#---------------- END -------------
Using this example, the “fix” that I suggested above works to avoid the error.
I need to remove the reference to “import defaults”. I’ll send a corrected file.
OK, this now stands alone…
import collections

import numpy as np
import pytensor
import pytensor.tensor as T
from pytensor.tensor.random.utils import RandomStream as RandomStreams


class Object:
    pass


cfg = Object()


def create_pp_fn():
    # create graph for preprocessing fn
    inp = T.tensor4()
    inp2 = inp
    # add noise for regularization
    inp2 = inp2 + cfg.pytensor_rng.normal(size=inp2.shape, scale=1, dtype=pytensor.config.floatX)
    pupdates = collections.OrderedDict()
    pupdates[cfg.input_pp] = inp2.reshape(cfg.input_pp.shape)
    pp_fn = pytensor.function(
        inputs=[],
        updates=pupdates,
        outputs=[],
        givens={inp: cfg.input_pp},
        on_unused_input="ignore",
        allow_input_downcast=True,
    )
    return pp_fn


cfg.seed = 15
cfg.batchsize = 25
cfg.numpy_rng = np.random.RandomState(cfg.seed)
cfg.pytensor_rng = RandomStreams(cfg.numpy_rng.randint(2**30))
cfg.input_chan = 1
cfg.nonlin_in = 0
cfg.ncol_in = 28
cfg.nrow_in = 28
batchshape = (cfg.batchsize, 1, cfg.nrow_in, cfg.ncol_in)
initial_pp = np.zeros(batchshape, dtype=pytensor.config.floatX)
cfg.input_pp = pytensor.shared(initial_pp)
cfg.pp_fn = create_pp_fn()
Here is a simplified case that reproduces the problem:
import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.tensor.random.utils import RandomStream
pytensor_rng = RandomStream(1)
batchshape = (25, 1, 28, 28)
inp_shared = pytensor.shared(np.zeros(batchshape), name="inp_shared")
inp = pt.tensor4(name="inp")
out = inp + pytensor_rng.normal(size=inp.shape, scale=1)
updates = {inp_shared: out.reshape(inp_shared.shape)}
pp_fn = pytensor.function(
inputs=[],
outputs=[],
updates=updates,
givens={inp: inp_shared},
mode="JAX",
)
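In the meantime, a possible workaround (untested sketch, not the actual fix): build the noise graph on the shared variable directly, so the function needs no givens substitution at compile time:

import numpy as np
import pytensor
from pytensor.tensor.random.utils import RandomStream

pytensor_rng = RandomStream(1)
batchshape = (25, 1, 28, 28)
inp_shared = pytensor.shared(np.zeros(batchshape), name="inp_shared")

# The update graph refers to inp_shared itself, so no separate input
# variable and no givens are involved.
out = inp_shared + pytensor_rng.normal(size=batchshape, scale=1)
pp_fn = pytensor.function(inputs=[], outputs=[], updates={inp_shared: out}, mode="JAX")
pp_fn()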
I opened an issue here: Function with only updates after givens fails in JAX mode · Issue #314 · pymc-devs/pytensor · GitHub
Thanks for your investigative work! I opened a PR with a fix here: Fix bug in JAX cloning of RNG shared variables by ricardoV94 · Pull Request #315 · pymc-devs/pytensor · GitHub
Thanks, very much appreciated. Once this and the dynamic indexing bug are fixed, I can port my software from Theano to PyTensor.
The dynamic slice is more of a new feature than a bug fix, so it might take some time for someone to take a look at it. If you are interested, you could try to add that functionality yourself.
Totally understandable if you don’t want to or can’t, of course.
Thanks. Yes, I’m not in a position to do that; I’m more of a user than a developer. But I can see that the error comes with a built-in helpful suggestion:
" To index a statically sized array at a dynamic position, try lax.dynamic_slice/dynamic_update_slice
(JAX does not support dynamically sized arrays within JIT compiled functions)."
So maybe the new feature is relatively easy.
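To illustrate what the message is pointing at, a toy JAX example of my own (separate from PyTensor internals):

import jax
import jax.numpy as jnp

@jax.jit
def take_window(x, start):
    # x[start:start + 4] would fail under jit because `start` is traced;
    # lax.dynamic_slice works because the slice *size* (4) stays static
    # and only the start position is dynamic.
    return jax.lax.dynamic_slice(x, (start,), (4,))

print(take_window(jnp.arange(10.0), 3))  # [3. 4. 5. 6.]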
That error message is more helpful for users writing simple JAX functions than for a wrapper library like PyTensor, which is trying to convert a whole graph into a JAX function.
For me, the advice is still some steps away from being actionable. I don’t even know if it’s a limitation we can actually overcome without experimenting with it.
Thanks, what would you suggest as my easiest solution? I suppose I could develop a new custom function to copy data with dynamic indexing into a fixed array. That might be easier than adding the general feature to PyTensor.
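Sketching the kind of thing I mean in plain JAX (hypothetical names and shapes; a real solution would still have to be wrapped as a PyTensor Op):

import jax
import jax.numpy as jnp

@jax.jit
def copy_block(dest, block, row, col):
    # Write `block` into `dest` at a runtime-determined (row, col) offset;
    # both array shapes stay static, which is all JIT requires.
    return jax.lax.dynamic_update_slice(dest, block, (row, col))

dest = jnp.zeros((6, 6))
block = jnp.ones((2, 3))
print(copy_block(dest, block, 2, 1))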