Pytensor cannot handle RandomType SharedVariables in mode=JAX

The error does not occur for mode=FAST_RUN
Yes, maybe there is another bug that leads to that happening. Let me work on it, and maybe I can find out when the ndarray gets into the graph.

Are you using a custom Op somewhere?

I was able to get a small program that causes the error:

#---------------- START -------------
import os
import scipy.io as scio
import numpy as np
import pytensor
import pytensor.tensor as T
from defaults import cfg, data, BACKEND
from pytensor.ifelse import ifelse
from pytensor.tensor.random.utils import RandomStream as RandomStreams
import collections

def create_pp_fn():
    # create graph for preprocessing fn
    inp=T.tensor4()
    inp2 = inp

    # add noise for regularization
    inp2 = inp2 + cfg.pytensor_rng.normal(size = inp2.shape, scale = 1, dtype = pytensor.config.floatX)

    pupdates = collections.OrderedDict()
    pupdates[cfg.input_pp] =  inp2.reshape(cfg.input_pp.shape)
    pp_fn = pytensor.function(inputs=[], \
            updates=pupdates, outputs=[],\
            givens={inp: cfg.input_pp},on_unused_input='ignore',allow_input_downcast=True)

    return pp_fn

cfg.seed=15
cfg.batchsize=25
cfg.numpy_rng = np.random.RandomState(cfg.seed)
cfg.pytensor_rng = RandomStreams(cfg.numpy_rng.randint(2**30))
cfg.input_chan=1
cfg.nonlin_in=0
cfg.ncol_in=28
cfg.nrow_in=28
batchshape = (cfg.batchsize,1,cfg.nrow_in, cfg.ncol_in)

initial_pp = np.zeros(batchshape, dtype=pytensor.config.floatX)
cfg.input_pp=pytensor.shared(initial_pp)
cfg.pp_fn = create_pp_fn()
#---------------- END -------------

Using this example, the 'fix' that I suggested above works to avoid the error.

I need to remove the reference to 'import defaults'. I'll send a corrected file.

OK, this now stands alone…

import os
import scipy.io as scio
import numpy as np
import pytensor
import pytensor.tensor as T
from pytensor.ifelse import ifelse
from pytensor.tensor.random.utils import RandomStream as RandomStreams
import collections
class Object:
    pass
cfg = Object()

def create_pp_fn():
    # create graph for preprocessing fn
    inp=T.tensor4()
    inp2 = inp

    # add noise for regularization
    inp2 = inp2 + cfg.pytensor_rng.normal(size = inp2.shape, scale = 1, dtype = pytensor.config.floatX)

    pupdates = collections.OrderedDict()
    pupdates[cfg.input_pp] =  inp2.reshape(cfg.input_pp.shape)
    pp_fn = pytensor.function(inputs=[], \
            updates=pupdates, outputs=[],\
            givens={inp: cfg.input_pp},on_unused_input='ignore',allow_input_downcast=True)

    return pp_fn

cfg.seed=15
cfg.batchsize=25
cfg.numpy_rng = np.random.RandomState(cfg.seed)
cfg.pytensor_rng = RandomStreams(cfg.numpy_rng.randint(2**30))
cfg.input_chan=1
cfg.nonlin_in=0
cfg.ncol_in=28
cfg.nrow_in=28
batchshape = (cfg.batchsize,1,cfg.nrow_in, cfg.ncol_in)

initial_pp = np.zeros(batchshape, dtype=pytensor.config.floatX)
cfg.input_pp=pytensor.shared(initial_pp)
cfg.pp_fn = create_pp_fn()

Here is a simplified case that reproduces the problem:

import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.tensor.random.utils import RandomStream

pytensor_rng = RandomStream(1)

batchshape = (25 ,1, 28, 28)
inp_shared = pytensor.shared(np.zeros(batchshape), name="inp_shared")

inp = pt.tensor4(name="inp")
out = inp + pytensor_rng.normal(size=inp.shape, scale=1)

updates = {inp_shared: out.reshape(inp_shared.shape)}

pp_fn = pytensor.function(
    inputs=[], 
    outputs=[], 
    updates=updates, 
    givens={inp: inp_shared},
    mode="JAX",
)

I opened an issue here: Function with only updates after givens fails in JAX mode · Issue #314 · pymc-devs/pytensor · GitHub

Thanks for your investigative work! I opened a PR with a fix here: Fix bug in JAX cloning of RNG shared variables by ricardoV94 · Pull Request #315 · pymc-devs/pytensor · GitHub

Thanks, very much appreciated. Once this and the dynamic indexing bug are fixed, I can port my software from Theano to Pytensor.

The dynamic slice is more of a new feature than a bugfix. It might take some time for someone to take a look at it. If you are interested, you could try to add that functionality yourself.

Totally understandable if you don't want to or can't, of course :slight_smile:

Thanks. Yes, I'm not really in a position to do that; I'm more of a user than a developer. But I can see that the error comes with a built-in helpful suggestion:

" To index a statically sized array at a dynamic position, try lax.dynamic_slice/dynamic_update_slice
(JAX does not support dynamically sized arrays within JIT compiled functions)."

So maybe the new feature is relatively easy. :grinning:
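
Here is a minimal sketch (plain JAX, not PyTensor) of what that error message seems to be pointing at; the array and function names are made up for illustration:

import jax
import jax.numpy as jnp
from jax import lax

x = jnp.arange(10.0)

@jax.jit
def take_window(start):
    # Writing x[start:start + 3] with a traced `start` fails under jit,
    # because JAX cannot build a slice whose bounds are only known at runtime.
    # lax.dynamic_slice works: the slice *size* (3,) is static, only the
    # start position is dynamic.
    return lax.dynamic_slice(x, (start,), (3,))

print(take_window(4))  # [4. 5. 6.]

Presumably the JAX backend would have to emit something like this whenever a slice index is symbolic, which is why it reads more like a new feature than a one-line fix.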

That error message is more helpful for users writing simple JAX functions than for a wrapper library like PyTensor, which is trying to convert a whole graph into a JAX function.

For me, the advice is still some steps away from being actionable. I don't even know if it's a limitation we can actually overcome without experimenting with it.

Thanks. What would you suggest as my easiest solution? I suppose I could develop a new custom function to copy data with dynamic indexing into a fixed array. That might be easier than adding the general feature to PyTensor.
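
To make the idea concrete, here is a rough sketch of what such a copy-into-a-fixed-array helper might boil down to on the JAX side, using the other primitive the error message mentions (copy_into is a hypothetical helper, not an existing PyTensor API; lax.dynamic_update_slice is the JAX primitive referred to):

import jax
import jax.numpy as jnp
from jax import lax

# Hypothetical helper: write `src` into the fixed-size `buf` at a dynamic
# start position. `buf` keeps its static shape, so this stays jit-friendly.
@jax.jit
def copy_into(buf, src, start):
    return lax.dynamic_update_slice(buf, src, (start,))

buf = jnp.zeros(10)
src = jnp.ones(3)
print(copy_into(buf, src, 4))
# [0. 0. 0. 0. 1. 1. 1. 0. 0. 0.]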