As_op vs pytensor for arbitrary matrix construction

Hi all, I’m in the middle of writing some generic code for a model that requires me to construct a unitary matrix from the way components are ordered on a chip. Because the components can be arranged arbitrarily, I need to build the total unitary by combining the component unitaries in the appropriate (arbitrary) order. When simulating data in numpy, I do this by replacing part of an identity matrix with the component unitary determined by the ordering of components on the chip: I use an external dictionary to track the component ordering, use numpy.pad to do the substitution, and repeat the multiplication until I have covered all the chip components and obtained the total unitary.
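To make the numpy construction concrete, here is a hedged sketch of the kind of thing described above (the helper name `embed` and the example values are invented, not from the attached code):

```python
import numpy as np

def embed(U2, n, k):
    """Embed a 2x2 component unitary at rows/cols k..k+1 of an n x n matrix."""
    # zero-pad the 2x2 block out to n x n with numpy.pad
    padded = np.pad(U2, ((k, n - k - 2), (k, n - k - 2)))
    # identity everywhere except the block we just placed
    eye = np.eye(n, dtype=complex)
    eye[k:k + 2, k:k + 2] = 0.0
    return eye + padded

# 50:50 beamsplitter as an example component unitary
bs = np.array([[np.sqrt(0.5), 1j * np.sqrt(0.5)],
               [1j * np.sqrt(0.5), np.sqrt(0.5)]])

n = 4
U = np.eye(n, dtype=complex)
for k in [0, 2, 1]:              # arbitrary component ordering from the chip
    U = embed(bs, n, k) @ U

# U remains unitary: U @ U.conj().T is the identity (up to float error)
```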

For the pymc part of it, I want to estimate parameters governing these chips, so I need to express this unitary inside the pymc model block. The arbitrary multiplication order of the different unitaries seems to rule out a simple vector-algebra formulation, so I see two options. One is pytensor.scan to loop through the components, with set_subtensor replacing the relevant part of the identity with the component unitary; but I believe pytensor.scan is quite computationally inefficient, which isn’t ideal since in the long run I’m trying to develop a scalable methodology. The alternative is as_op, but that wrapper requires you to define your itypes, and I don’t know whether this can be automated from just declaring that I have, e.g., n components (so n inputs) rather than writing itypes=[pt.dscalar, pt.dscalar, ...].

So finally, the question: is one option a better choice than the other? pt.scan’s computational inefficiency impedes scalability, but will as_op even be appropriate in this case? And if so, is there a way to avoid writing out pt.dscalar over and over in itypes, since that wouldn’t be very scalable either!

Thanks in advance. I appreciate that some of this may be unclear, so if this conversation develops I’ll be happy to draft up some figures to elaborate on exactly what I’m doing, since trying to explain it all up front could get a bit involved and I didn’t want to drown out the question! Please find attached my (very incomplete) code if it helps. (9.9 KB)

Well, scan gets a bad rap. I’m guilty of trash talking it. The truth is that it’s fine, you will just want to run your model in JAX or numba mode.

If you want to go with writing an Op, realize that 1) you don’t get gradients for free, and 2) you don’t get numba or JAX (and thus nutpie and numpyro/blackjax) for free. So it should really be considered a lesser option.

I had a quick look at your code but wasn’t able to quickly grok it. At first blush it didn’t seem like the loop to compute the U matrix was recursive as such? You do U = U @ new_U at the end, but this isn’t used in the construction of new_U at each step, so you could also just do something like U_list = [circuit_list_to_matrix(feature) for feature in circuit_list] followed by pt.linalg.matrix_dot(U_list).


A numpy loop wrapped in an as_op is not going to be faster. When people say scan is inefficient it’s usually in reference to non looping “vectorized” code.

It also depends very much on what you’re doing and which backend you’re compiling to. Scan can be rather fast in numba and jax.


Ahh OK, thanks to you both! I shall pursue the iterative approach then, and try to get numba/jax speedups from there.

Following up on implementing my problem: I went along the lines of the approach @jessegrabowski suggested, but I think in defining the circuit_list_to_matrix function I’ve got an error that requires a rewrite.

Code block:

def construct_BS_pymc(RV):
    return pt.stack([pm.math.sqrt(RV), 1j*pm.math.sqrt(1-RV),
                     1j*pm.math.sqrt(1-RV), pm.math.sqrt(RV)]).reshape((2, 2))

def construct_PS_pymc(phi):  # argument renamed: the body uses phi, not RV
    return pt.stack([pm.math.exp(1j*phi/2), 0, 0,
                     pm.math.exp(-1j*phi/2)]).reshape((2, 2))

def circuit_list_to_matrix_pymc(feature):
    # Work in progress
    if feature[0] == "BS":
        # print(matrix[index:index+2][index:index+2].eval())  # Not what I want -> [1,1]
        # (chained [a:b][a:b] slices rows twice; matrix[index:index+2, index:index+2]
        # selects the intended 2x2 block)
        pass
    if feature[0] == "PS":
        for _ in range(len(feature[1])):
            pass  # body still to be written
    return matrix  # return output matrix

My suspicion is that set_subtensor doesn’t like an advanced replacement like this, where I’m trying to substitute a 2x2 matrix into a matrix of the same size or larger. I just want to check whether this suspicion is correct while I’m rewriting my code around it. Is this right, and if so, is there a reason it hasn’t been implemented yet? If it’s just a case of not having got round to it, and there isn’t some fundamental reason it wouldn’t work, then I’d be happy to try and take a crack at the problem.

set_subtensor requires something that looks like set_subtensor(x[idx], y). It’s complaining because you don’t seem to have something akin to x[idx]. This is because, under the hood, set_subtensor is defined as an operation that takes x, idx, and y as inputs.


Thanks Ricardo! I’ve been able to remedy it and get the code running (well, nearly running) by expressing set_subtensor in the way you suggested :grinning: