Implementing Crank-Nicolson Method to Solve 1D Heat Diffusion

So I implemented this in PyTensor using scan like this

import os
os.environ['OMP_NUM_THREADS'] = '1'
import pymc as pm
import pytensor as at

def X_n_maker(d_factor, size):
    x_size = at.zeros(shape=(at.shape(size),at.shape(size)))
    Xn_1 = at.extra_ops.fill_diagonal_offset(x_size, d_factor, -1)
    
    Xn_2a = at.extra_ops.fill_diagonal_offset(x_size, 1-2.*d_factor, 0)
    Xn_2a1 = at.set_subtensor(Xn_2a[0,0],1-d_factor)
    Xn_2 = at.set_subtensor(Xn_2a1[-1,-1],1-d_factor)
    
    Xn_3 = at.extra_ops.fill_diagonal_offset(x_size, d_factor, 1)

    return Xn_1 + Xn_2 + Xn_3


def function_inner_loop(dt, x, ds, param):
        # Function 1
        d_factor = param[0]
        A_n = X_n_maker(-d_factor, x)
        B_n = X_n_maker(d_factor, x)

        x_1 = at.linalg.solve(A_n, at.dot(B_n, x), assume_a='sym')
        x_2 = at.set_subtensor(x_1[0], x_1[1] / (param[1]*ds + 1))
        x_3 = at.set_subtensor(x_2[-1], x_2[-2] / (param[2]*ds -1))       
        
        # Function 2
        x_next = function2(x_3, dt, params)
        
        return x_next


def function_outer_loop(dt, x, ds, param):
    
        result_inner, _ = pytensor.scan(fn=function_inner_loop,
                                                sequences=[dt],
                                                outputs_info=[x_0],
                                                non_sequences=[ds, param])    

        x_0 = x_0.dimshuffle('x',0)
        x_calc = at.concatenate([x_0, result_inner], axis=0)

        return x_calc 


### Part of main code
result, _ = pytensor.scan(fn=function_outer_loop, sequences=[samples], outputs_info=[at.zeros(shape=(at.shape(time), at.shape(x_0)], non_sequences=[ds, param])    


x_result = at.concatenate([x_0, result], axis=0)

Here, x is an array with length 200 and params are pymc random variables and sample is an array with length 4, so that I can loop over all four samples.

This implementation runs fine, but is quite slow: When I run the same code, but skip everything in # Function 1 , a chain of 100 samples takes about 1 sec., whereas, if I include Function 1, it took about 10 min. and all the CPUs are maxed out.
This reminded me of a post from @aseyboldt from a while ago.
I have therefore started using os.environ['OMP_NUM_THREADS'] = '1' as suggested here.

This brings it down to 2 min. run time for 100 samples, but still the single CPU that is used is maxed out (if I used ‘5’, then all 5 are maxed out). I feel like there may be something else wrong with this code or the way that pytensor parallelizes at.linalg.solve and at.dot.

Any help is highly appreciated!