You shouldn’t need to pass RNGs around, because your function is completely deterministic.
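You should not be getting non-square matrices inside the log_jcd_elem function: each call should see one 2-vector of outputs and the matching 2-vector of inputs, so the Jacobian is (2, 2). You want to arrange things something like this: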
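import pytensor.tensor as pt
import pytensor.gradient as tg  # assumed alias for tg

def log_jcd_elem(A, x):
    # A and x are single rows here, each of shape (2,)
    jac = tg.jacobian(A, x)  # (2, 2)
    sign, logdet = pt.linalg.slogdet(jac)  # logdet = log|det(jac)|
    return logdet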
Scan iterates over the leftmost dimension of its sequences, so as long as f_inv_x and x both have shape (10, 2), you just need:
log_jcds, _ = pytensor.scan(log_jcd_elem, sequences=[f_inv_x, x])
If that doesn’t work, I’ll take a more careful look at the gist and see if I can figure something out.