Sampler not initialising

Hello all,

I’ve got some code that gives me what I expect when I evaluate its outputs, but unfortunately it stumbles at the last hurdle, which is the sampler actually starting!

Attached is the most relevant part of the code and the code file itself:
Polished.py (14.1 KB)

def construct_BS_pymc_real(RV):
    """Return the real part of the beam-splitter matrix as a flat 4-vector.

    The complex matrix being split into parts is
    ``[[sqrt(RV), 1j*sqrt(1-RV)], [1j*sqrt(1-RV), sqrt(RV)]]``, so only the
    diagonal carries a real component. The entries are stacked in row-major
    order and left flat; the caller slices the result into two halves.
    """
    diagonal = pm.math.sqrt(RV)
    return pt.stack([diagonal, 0, 0, diagonal])
    
def construct_BS_pymc_imag(RV):
    """Return the imaginary part of the beam-splitter matrix as a flat 4-vector.

    The complex matrix being split into parts is
    ``[[sqrt(RV), 1j*sqrt(1-RV)], [1j*sqrt(1-RV), sqrt(RV)]]``, so only the
    off-diagonal entries carry an imaginary component. The entries are stacked
    in row-major order and left flat; the caller slices the result into halves.
    """
    off_diagonal = pm.math.sqrt(1 - RV)
    return pt.stack([0, off_diagonal, off_diagonal, 0])
    
def construct_PS_pymc_real(RV):
    """Return the value written into the real part of the phase-shifter matrix.

    NOTE(review): the complex form this was derived from is
    ``diag(exp(1j*RV/2), exp(-1j*RV/2))``, whose real part is ``cos(RV/2)``,
    whereas this returns ``sin(RV)``. Confirm the phase convention is intended.
    """
    real_entry = pm.math.sin(RV)
    return real_entry
    
def construct_PS_pymc_imag(RV):
    """Return the value written into the imaginary part of the phase-shifter matrix.

    NOTE(review): the complex form this was derived from is
    ``diag(exp(1j*RV/2), exp(-1j*RV/2))``, whose imaginary part is
    ``sin(RV/2)``, whereas this returns ``cos(RV)``. Confirm the phase
    convention is intended.
    """
    imag_entry = pm.math.cos(RV)
    return imag_entry
    
def complex_matmult_pymc(A, B):
    """Multiply two complex matrices given as ``(real, imag)`` pairs.

    With ``A = Ar + i*Ai`` and ``B = Br + i*Bi`` the product is
    ``(Ar.Br - Ai.Bi) + i*(Ar.Bi + Ai.Br)``.

    Parameters
    ----------
    A, B : pair of tensors
        Each is unpacked into its real and imaginary parts, in that order.

    Returns
    -------
    tuple
        ``(real, imag)`` parts of the product ``AB``.
    """
    left_re, left_im = A
    right_re, right_im = B
    product_re = pt.dot(left_re, right_re) - pt.dot(left_im, right_im)
    product_im = pt.dot(left_re, right_im) + pt.dot(left_im, right_re)
    return product_re, product_im
    
def _splice_2x2_block(base, flat_block, index):
    """Write a flat, row-major 2x2 block into ``base`` with its corner at (index, index).

    The assignment is done on the flattened matrix, one block-row at a time,
    and the result is reshaped back to ``(m, m)``. For ``m == 2`` and
    ``index == 0`` the two slices are ``[0:2]`` and ``[2:4]``.
    """
    flat = pt.flatten(base)
    top_row_start = index * m + index       # flat offset of element (index, index)
    bottom_row_start = top_row_start + m    # same column, one row further down
    flat = pt.set_subtensor(flat[top_row_start:top_row_start + 2], flat_block[0:2])
    flat = pt.set_subtensor(flat[bottom_row_start:bottom_row_start + 2], flat_block[2:4])
    return flat.reshape((m, m))


def circuit_list_to_matrix_pymc(feature):
    """Build the per-experiment ``(real, imag)`` matrices for one circuit element.

    Parameters
    ----------
    feature : sequence
        ``[kind, parameter, mode]`` where ``kind`` is ``"BS"`` or ``"PS"``,
        ``parameter`` is the symbolic value passed to the matching
        ``construct_*`` helpers and ``mode`` is the 1-based mode number.
        For ``"PS"`` the parameter is indexed as ``parameter[0][k]`` for each
        experiment ``k`` -- presumably a (1, N) tensor; verify against caller.

    Returns
    -------
    list of tuple
        ``N`` pairs ``(real, imag)`` of ``m`` x ``m`` tensors, one per
        experiment. ``m`` and ``N`` are read from the enclosing scope.

    Raises
    ------
    ValueError
        If ``kind`` is neither ``"BS"`` nor ``"PS"``.
    """
    kind, parameter = feature[0], feature[1]
    index = feature[2] - 1

    if kind == "BS":
        # The real part starts from the identity so untouched modes pass
        # straight through; the imaginary part of the identity is zero.
        real = _splice_2x2_block(pt.eye(n=m, m=m), construct_BS_pymc_real(parameter), index)
        imag = _splice_2x2_block(pt.zeros((m, m)), construct_BS_pymc_imag(parameter), index)
        # The same beam splitter is shared by all N experiments.
        return [(real, imag)] * N

    if kind == "PS":
        matrices = []
        for experiment in range(N):
            phase = parameter[0][experiment]
            real = pt.set_subtensor(pt.eye(n=m, m=m)[index, index], construct_PS_pymc_real(phase))
            # Start from zeros: an identity here would leave 1j on every
            # other diagonal entry of the complex matrix.
            imag = pt.set_subtensor(pt.zeros((m, m))[index, index], construct_PS_pymc_imag(phase))
            matrices.append((real, imag))
        return matrices

    raise ValueError(f"Unknown circuit element {kind!r}; expected 'BS' or 'PS'")

with pm.Model():

    # --- Free parameters to infer ---
    eta = pm.TruncatedNormal("eta", mu=0.5, sigma=0.05, lower=0.0, upper=1.0, initval=0.5)
    theta = pm.Deterministic("theta", 2 * pt.arccos(pt.sqrt(eta)))
    a = pm.TruncatedNormal("a", mu=0, sigma=a_dev, lower=-np.pi, upper=np.pi, initval=0)
    b = pm.Normal("b", mu=0.7, sigma=0.7, initval=0.7)

    # Alternative considered: Volt = pm.Normal("Volt", mu=V_2_dist, sigma=0.1)
    Volt = pm.Deterministic("Volt", pt.as_tensor(V))

    # phi describes the different phase shifts for the different experiments.
    # The broadcasting form a[:,None]+b[:,None]*pm.math.sqr(Volt) breaks down
    # when there is just one a and b, hence the scalar form below.
    phi = pm.Deterministic("phi", a + b * pm.math.sqr(Volt))

    circuit_list = [["BS", eta, 1], ["PS", phi, 1]]  # Need to reverse this order for it to be correct

    # One list of N (real, imag) pairs per circuit element. Kept as plain
    # Python lists: wrapping symbolic tensors in np.array is unnecessary.
    U_list = [circuit_list_to_matrix_pymc(feature) for feature in circuit_list]

    # Final mode unitaries: U = [(U1, U1i), (U2, U2i), ..., (UN, UNi)]
    U = []
    for element_matrices in zip(*U_list):
        # element_matrices holds every circuit element's matrix for one experiment
        product = element_matrices[0]
        for factor in element_matrices[1:]:  # do not reuse the name `a` (the prior)
            product = complex_matmult_pymc(product, factor)
        # Append once per experiment, after the whole chain has been multiplied
        U.append(product)

    # --- Indexing specific elements from each matrix ---
    Utopreal = [elem[0][0][0] for elem in U]  # top left element of each real matrix in U
    Utopimag = [elem[1][0][0] for elem in U]  # top left element of each imag matrix in U
    Ubotreal = [elem[0][1][0] for elem in U]  # bottom left element of each real matrix in U
    Ubotimag = [elem[1][1][0] for elem in U]  # bottom left element of each imag matrix in U

    # Big slowdown when attempting to call sampling, text indicating
    # initialisation doesn't even show up.
    # Probabilities are the squared moduli of the selected matrix elements.
    P = pm.math.stack(
        [
            pm.math.sqr(Utopreal) + pm.math.sqr(Utopimag),
            pm.math.sqr(Ubotreal) + pm.math.sqr(Ubotimag),
        ],
        axis=-1,
    )

    likelihood = pm.Multinomial("likelihood", n=C, p=P, shape=(N, m), observed=data)

    trace = pm.sample(draws=int(1e3), chains=4, cores=cpucount, return_inferencedata=True)

The problem is that when I run this code (full file also attached), the sampler takes forever to initialise once sampling is called. I’ve heard this is because of ‘lazy computation’, where the computation is only actually done when sampling is called. As to why it takes forever to initialise, I’m assuming it could be because these repeated for loops and arrays are making the graph messy? I’ve drummed up some code that uses pytensor.scan, in case that would be more appropriate for this recursive complex matrix multiplication from the earlier post, but I’m not sure how to implement it within this model block:

# Symbolic inputs: the number of scan steps and the real/imaginary parts of
# the two complex matrices.
k = pt.iscalar("k")
A, Ai, B, Bi = (pt.matrix(label) for label in ("A", "Ai", "B", "Bi"))

def mat_compmult(A, Ai, B, Bi):
    """Return the real and imaginary parts of ``(A + i*Ai) @ (B + i*Bi)``.

    The four parts are taken as separate arguments (rather than as pairs) so
    the function can be passed directly as the step function of a scan.
    """
    real_part = pt.dot(A, B) - pt.dot(Ai, Bi)
    imag_part = pt.dot(A, Bi) + pt.dot(Ai, B)
    return real_part, imag_part

# Symbolic description of the result.
# Each step multiplies the running complex product (carried through
# outputs_info as a real/imaginary pair) by the fixed matrix B + i*Bi.
result, updates = pytensor.scan(fn=mat_compmult,
                            outputs_info=[A, Ai],
                            non_sequences=[B, Bi],
                            n_steps=k)

# With two recurrent outputs, `result` is a list of two sequences (real parts
# and imaginary parts over all steps). `result[-1]` would therefore be the
# whole imaginary history, so take the last step of each sequence instead.
real_steps, imag_steps = result
final_result = [real_steps[-1], imag_steps[-1]]

power = pytensor.function(inputs=[A, Ai, B, Bi, k], outputs=final_result,
                      updates=updates)

# Prints the real and imaginary parts of the final product
print(power(np.ones((2,2)), np.ones((2,2)), np.ones((2,2)), np.ones((2,2)), 2))

Any help would be appreciated!