Only one chain running, three are stuck

Yes, this function is the result of another Discourse post:

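# Imports assumed from the `at` alias used below (Aesara); they were not shown in the original snippet
import aesara.tensor as at
from aesara.graph.basic import Apply
from aesara.graph.op import Op


class RootFinder(Op):

    def __init__(self, t):
        self.t = t

    def make_node(self, D, a, b):
        # Convert the inputs to symbolic tensor variables before building the node
        D = at.as_tensor_variable(D)
        a = at.as_tensor_variable(a)
        b = at.as_tensor_variable(b)
        outputs = [at.vector(dtype='float64')]
        return Apply(self, [D, a, b], outputs)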
    
    def perform(self, node, inputs, outputs_storage):
        D, a, b = inputs
        
        outputs_storage[0][0] = root_finder_loop_perform(D, a, b, self.t)
    
    def grad(self, inputs, output_gradients):
        D, a, b = inputs
        x_list = self(D, a, b)

        # Partial derivatives of 4*a*D*x/(a**2*b - 4*D**2*x**2) with respect to x, D, a and b
        x_grad_list = a*D/((a**2*b)/4 - D**2*x_list**2) + 2*a*D**3*x_list**2/((a**2*b)/4 - D**2*x_list**2)**2
        D_grad_list = 4*a*(a**2*b*x_list + 4*D**2*x_list**3)/(a**2*b - 4*D**2*x_list**2)**2
        a_grad_list = -4*D*x_list*(a**2*b + 4*D**2*x_list**2)/(a**2*b - 4*D**2*x_list**2)**2
        b_grad_list = -4*a**3*D*x_list/(a**2*b - 4*D**2*x_list**2)**2

        # Implicit differentiation of the roots, contracted with the upstream gradient
        grad_D = at.dot((-D_grad_list/x_grad_list), output_gradients[0])
        grad_a = at.dot((-a_grad_list/x_grad_list), output_gradients[0])
        grad_b = at.dot((-b_grad_list/x_grad_list), output_gradients[0])

        return grad_D, grad_a, grad_b

where root_finder_loop_perform() is a function that returns the first n roots of a function, and D, a and b are random scalar variables. The grad() method is expected to return three zero-dimensional outputs. I have confirmed that it passes verify_grad() and that it runs smoothly with a single chain.
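
For reference, the ratios of the form -(df/dparam)/(df/dx) in grad() are the implicit-differentiation pattern for a root x(p) of f(x, p) = 0. Below is a minimal plain-Python sketch of that pattern. It assumes a hypothetical test function f(x, p) = x**3 + p*x - 1, which is not the function from this post, and a simple bisection in place of root_finder_loop_perform(). The names f, find_root and implicit_grad are made up for illustration.

```python
def f(x, p):
    # Hypothetical test function (an assumption, not the one used in RootFinder)
    return x**3 + p * x - 1.0


def find_root(p, lo=0.0, hi=1.0, tol=1e-12):
    # Plain bisection; for p > 0, f(0, p) < 0 < f(1, p) and f is increasing in x
    while hi - lo > tol:
        mid = 0.5 * (lo + hi)
        if f(mid, p) > 0.0:
            hi = mid
        else:
            lo = mid
    return 0.5 * (lo + hi)


def implicit_grad(p):
    # dx/dp = -(df/dp) / (df/dx), evaluated at the root
    x = find_root(p)
    df_dx = 3.0 * x**2 + p
    df_dp = x
    return -df_dp / df_dx


# Compare against a central finite difference of the root itself
p = 2.0
eps = 1e-6
fd = (find_root(p + eps) - find_root(p - eps)) / (2 * eps)
print(implicit_grad(p), fd)
```

The two printed numbers should agree closely, which is the same kind of check that verify_grad() performs on the Op.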

I then call the function via:

rootfinder = RootFinder(t=t)
x_list = rootfinder(D, a, b)

and use the resulting length-n vector x_list in the next steps.
As I said previously, this works fine with a single chain, but with multiple chains only one chain runs and the others get stuck.
Any help is highly appreciated!