PyTensor cannot handle RandomType SharedVariables in mode=JAX

Cannot use random variables in the PyTensor graph.
For example:

from pytensor.tensor.random.utils import RandomStream as RandomStreams
pytensor_rng = RandomStreams(cfg.numpy_rng.randint(2**30))
output = pytensor_rng.normal(size=inp2.shape, scale=1, dtype=pytensor.config.floatX)

Then, if the variable output appears in the graph, I get this warning:

/home/paul.baggenstoss/miniconda3/lib/python3.9/site-packages/pytensor/link/jax/linker.py:28: UserWarning: The RandomType SharedVariables [RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7FBEE4081200>)] will not be used in the compiled JAX graph. Instead a copy will be used.

Followed by an error:

  File "/home/paul.baggenstoss/miniconda3/lib/python3.9/site-packages/pytensor/link/jax/linker.py", line 47, in fgraph_convert
    input_storage[input_storage.index(old_inp_storage)] = new_inp_storage
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

This works for me in the latest version of PyTensor.

import numpy as np

import pytensor
import pytensor.tensor as pt
from pytensor.tensor.random.utils import RandomStream as RandomStreams


pytensor_rng = RandomStreams(123)
inp = pt.matrix("inp")
output = pytensor_rng.normal(size=inp.shape, scale=1, dtype=pytensor.config.floatX)

fn = pytensor.function([inp], output, mode="JAX")
fn(np.zeros((6, 6)))

What exactly is the code that fails on your end?
If you can make it as small and reproducible as possible, that would be better.

If you are working with PyMC and not just PyTensor, we don't recommend you use RandomStream. You can simply define PyMC variables, and when we compile a function via compile_pymc we will automatically pick up the updates for you (you only need to handle those from Scan, as they are special).
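For example, something like this (a minimal sketch, assuming PyMC >= 5, where compile_pymc lives in pymc.pytensorf; the variable and shape here are only illustrative):

import pymc as pm
from pymc.pytensorf import compile_pymc

# Define a PyMC random variable directly, no RandomStream needed
x = pm.Normal.dist(mu=0.0, sigma=1.0, size=(6, 6))

# compile_pymc collects the RNG updates automatically
fn = compile_pymc(inputs=[], outputs=x, mode="JAX")
print(fn())
print(fn())  # a new draw on each call, because the update was picked up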

Below is the code I used. It gets the warning:

pytensor/link/jax/linker.py:28: UserWarning: The RandomType SharedVariables [RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7FD8D548F660>)] will not be used in the compiled JAX graph. Instead a copy will be used.

I have PyTensor 2.11.1, JAX 0.4.9 [gpu], cuDNN 8.8, and CUDA 12.1.

#-------------------------------------
import pytensor
import pytensor.tensor as T
import numpy as np
from pytensor.tensor.conv import conv2d


# ---------------------- constants  ------------------
bs=25
input_chan=1
nrow_in=28
ncol_in=28
k = 8
nchan_out=4
border_mode = 'valid'

from pytensor.tensor.random.utils import RandomStream as RandomStreams
pytensor_rng = RandomStreams(123)

# ---------------------- symbolic function ------------------
batch_input_shape = (bs, nrow_in, ncol_in)
filt_shape = (ncol_in, nchan_out)
xin = T.tensor3()
filt = T.matrix()
xin_r = xin.reshape(batch_input_shape)
filt_r = filt.reshape(filt_shape)
z = T.dot(xin_r, filt_r)

# add noise from the shared RNG; this is what triggers the JAX warning
z = z + pytensor_rng.normal(size=z.shape, scale=1, dtype=pytensor.config.floatX)
cfn = pytensor.function(inputs=[xin, filt], outputs=[z], updates=None)

# ---------------------- python code ------------------
seed = 15
numpy_rng = np.random.RandomState(seed)
x = numpy_rng.uniform(low=-1.0, high=1.0, size=batch_input_shape).astype(pytensor.config.floatX)
W = numpy_rng.uniform(low=-1.0, high=1.0, size=filt_shape).astype(pytensor.config.floatX)
zout = cfn(x, W)
print(zout)

The warnings are not an error. They will always happen with shared RNG variables in JAX, because we can't use the same RNG objects with JAX as with other backends.

Is something failing for you, or are you just getting the warnings?

The warning seemed ominous. As long as it does not affect the functionality, that's OK, thanks.

Yes, it just means we had to copy the contents of the RNG shared variables into new ones so as not to affect the original RNG shared variables. Random seeding will still work fine, as will updates.

This is only problematic if you wanted to control seeding manually between function evaluations by changing the value of your original RNG shared variables.
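For example, this sketch shows what that manual control looks like in the default backend (it assumes the graph has a single RandomVariable node, so output.owner.inputs[0] is its RNG shared variable):

import numpy as np
import pytensor
from pytensor.tensor.random.utils import RandomStream as RandomStreams

pytensor_rng = RandomStreams(123)
output = pytensor_rng.normal(size=(3,), scale=1, dtype=pytensor.config.floatX)

# The RNG shared variable is the first input of the RandomVariable node
rng_shared = output.owner.inputs[0]

fn = pytensor.function([], output)  # default backend

# Reseed manually between evaluations by overwriting the shared RNG
rng_shared.set_value(np.random.default_rng(42), borrow=True)
draw1 = fn()
rng_shared.set_value(np.random.default_rng(42), borrow=True)
draw2 = fn()
assert np.allclose(draw1, draw2)  # same seed -> same draw

# With mode="JAX" the compiled function works off a copy of rng_shared,
# so overwriting rng_shared afterwards would not affect its draws.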

Thanks, I tested the functionality and it is OK. I guess this is closed.

There is still a problem. When using random generation in JAX, I still get the error

  File "/home/paul.baggenstoss/miniconda3/lib/python3.9/site-packages/pytensor/link/jax/linker.py", line 52, in fgraph_convert
    input_storage[input_storage.index(old_inp_storage)] = new_inp_storage
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

I found out that the reason for this is that old_inp_storage is an array, and a Python list's index() method fails when it has to compare arrays. The problem might be introduced by the use of a copy of the random array. For example:

>>> import numpy as np
>>> a = [1, 2, 3, np.asarray([1, 2, 3, 4]), 5, 6]
>>> a.index(3)
2
>>> a.index(a[3])
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

What do you think?

Maybe it would work if the array was passed to as_tensor_variable first…

I don't know how to fix it, but I can detect when the error will happen. I put this test in pytensor/link/jax/linker.py (starting at line 44):

            new_inp_storage = [new_inp.get_value(borrow=True)]
            storage_map[new_inp] = new_inp_storage
            old_inp_storage = storage_map.pop(old_inp)
            # ------------------- PMB ----------------
            for i, s in enumerate(input_storage):  # PMB
                for j, d in enumerate(s):
                    if isinstance(d, np.ndarray):  # PMB
                        print('Found Numpy array in graph')
            # ------------------- PMB ----------------

And I should add that this only detects an ndarray when I use the random number generator in the graph.

I have a fix. The problem had nothing to do with the random number generator itself; having the RNG in the graph merely causes execution to reach this line of code. As long as a NumPy ndarray exists in the graph, list.index() does not work. I have a suggested fix, which I have tested. Here is the fix, inserted at line 44 of pytensor/link/jax/linker.py:

    Original code:

        input_storage[input_storage.index(old_inp_storage)] = new_inp_storage

    Suggested fix:

        for i, s in enumerate(input_storage):
            if not isinstance(s[0], np.ndarray):
                if input_storage[i] == old_inp_storage:
                    input_storage[i] = new_inp_storage

I do not have enough knowledge of PyTensor to know if this fix is good.
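For what it's worth, a lookup based on object identity would sidestep the ambiguous array comparison entirely. This is only a sketch of the idea, not a tested patch:

# Hypothetical alternative to input_storage.index(old_inp_storage):
# compare the storage cells by identity, so that NumPy arrays inside
# input_storage are never truth-tested via ==.
idx = next(i for i, s in enumerate(input_storage) if s is old_inp_storage)
input_storage[idx] = new_inp_storage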

Can you share a small reproducible example that leads to that index error?

Hi Ricardo, I did this in the bug report.

Oh, wait I got this issue mixed up with another. I’ll try to get a small example…


I can't find a simple example; it only gives the error for my complicated software… do you know a way to force the graph to contain an np ndarray? That could then trigger the error, as long as an RNG is also in the graph.

I am not sure what you mean (I don’t yet understand the problem you’re seeing). Can you start from your full program and simplify as much as possible while still hitting the error?

In order for the error to occur, two things must happen at the same time.
At line 44 of pytensor/link/jax/linker.py it is necessary that this triggers:

for i, s in enumerate(input_storage):
    if isinstance(s[0], np.ndarray):
        ...  # needs to get here

which means the graph needs to contain a NumPy ndarray. And the test on line 28 must also trigger. Is that exact enough?

The problem I have is that I have no idea what causes the graph to have an ndarray in it, so I don't know how to reduce my very complex software. If I could have a clue, that would help.

Yes, I also don't understand. The input_storage should be composed of nested lists. So somewhere a list is being replaced by an array, when probably that array should be stored in the first entry of that list.

input_storage = [[None], [None], [None]]  # 3 inputs
# Initialize inputs to 0
for i in range(3):
    input_storage[i][0] = 0

This only happens in the JAX backend?

Without an example it’s hard to see where the problem can be. If nothing else reproduces it, can you share your full program?