Yes, this function is the result of another Discourse post:
import aesara.tensor as at  # assuming Aesara, as the `at` alias suggests
from aesara.graph.basic import Apply
from aesara.graph.op import Op


class RootFinder(Op):
    def __init__(self, t):
        self.t = t

    def make_node(self, D, a, b):
        # One float64 vector output holding the roots.
        outputs = [at.vector(dtype='float64')]
        return Apply(self, [D, a, b], outputs)

    def perform(self, node, inputs, outputs_storage):
        D, a, b = inputs
        outputs_storage[0][0] = root_finder_loop_perform(D, a, b, self.t)

    def grad(self, inputs, output_gradients):
        D, a, b = inputs
        x_list = self(D, a, b)
        # Partial derivatives with respect to x, D, a and b, evaluated at the roots.
        x_grad_list = a*D/((a**2*b)/4 - D**2*x_list**2) + 2*a*D**3*x_list**2/((a**2*b)/4 - D**2*x_list**2)**2
        D_grad_list = 4*a*(a**2*b*x_list + 4*D**2*x_list**3)/(a**2*b - 4*D**2*x_list**2)**2
        a_grad_list = -4*D*x_list*(a**2*b + 4*D**2*x_list**2)/(a**2*b - 4*D**2*x_list**2)**2
        b_grad_list = -4*a**3*D*x_list/(a**2*b - 4*D**2*x_list**2)**2
        # Implicit differentiation, dx/dp = -(dF/dp)/(dF/dx), dotted with the upstream gradient.
        grad_D = at.dot((-D_grad_list/x_grad_list), output_gradients[0])
        grad_a = at.dot((-a_grad_list/x_grad_list), output_gradients[0])
        grad_b = at.dot((-b_grad_list/x_grad_list), output_gradients[0])
        return grad_D, grad_a, grad_b
where root_finder_loop_perform() is a function that returns the first n roots of a function, and D, a and b are random scalar variables. The grad() function is expected to return three zero-dimensional outputs. I have confirmed that it works with verify_grad(), and it runs smoothly with a single chain.
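To spell out the pattern grad() relies on: if the roots satisfy F(x, p) = 0, implicit differentiation gives dx/dp = -(dF/dp)/(dF/dx), and the gradient for p is that vector dotted with the upstream gradient. Below is a minimal sketch on a toy equation, F(x, p) = x**2 - k*p for k = 1, 2, 3, which is not my actual function. The names toy_roots, implicit_grad and finite_difference_grad are hypothetical and only there for illustration.

```python
import math

K = (1.0, 2.0, 3.0)  # toy constants; each k defines one root equation


def toy_roots(p):
    """Positive roots of F_k(x, p) = x**2 - k*p = 0 (hypothetical toy problem)."""
    return [math.sqrt(k * p) for k in K]


def implicit_grad(p, upstream):
    """Gradient of sum(upstream[i] * x_i(p)) with respect to p."""
    xs = toy_roots(p)
    dF_dx = [2.0 * x for x in xs]  # dF/dx at each root
    dF_dp = [-k for k in K]        # dF/dp at each root
    dx_dp = [-fp / fx for fp, fx in zip(dF_dp, dF_dx)]
    # Vector-Jacobian product: same role as at.dot(..., output_gradients[0])
    return sum(d * g for d, g in zip(dx_dp, upstream))


def finite_difference_grad(p, upstream, eps=1e-6):
    """Central-difference reference for the same quantity."""
    hi = sum(g * x for g, x in zip(upstream, toy_roots(p + eps)))
    lo = sum(g * x for g, x in zip(upstream, toy_roots(p - eps)))
    return (hi - lo) / (2.0 * eps)


upstream = [0.5, -1.0, 2.0]
analytic = implicit_grad(2.0, upstream)
numeric = finite_difference_grad(2.0, upstream)
print(abs(analytic - numeric) < 1e-6)
```

The finite-difference comparison plays the same role here that verify_grad() plays for the real Op.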
I then call the function via:
rootfinder = RootFinder(t=t)
x_list = rootfinder(D, a, b)
and use the n-dimensional x_list values in the next steps.
As I said previously, this works fine for a single chain, but doesn’t seem to like multiple chains.
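One thing I have not ruled out, and this is only a guess: if the chains are run in separate worker processes, the Op (including whatever is stored in self.t) would have to survive pickling. A minimal round-trip check might look like the sketch below. It uses a plain dict as a hypothetical stand-in for the Op's state rather than the real RootFinder, and survives_pickle is a made-up helper name.

```python
import pickle


def survives_pickle(obj):
    """Return True if obj can be pickled and restored, False otherwise."""
    try:
        pickle.loads(pickle.dumps(obj))
        return True
    except Exception:
        return False


# Hypothetical stand-ins for the state held by the Op (its `t` attribute).
plain_state = {"t": [0.1, 0.2, 0.3]}
lambda_state = {"t": lambda s: s}  # lambdas cannot be pickled

ok = survives_pickle(plain_state)
bad = survives_pickle(lambda_state)
print(ok, bad)
```

With the real Op, the same check would be survives_pickle(rootfinder).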
Any help is highly appreciated!