Hitting a weird error to do with RNGs in Scan in a custom function inside a Potential

What if you do it “by hand”, i.e.:

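idx = pt.arange(y_c.ravel().shape[0])  # flat index i into the raveled (n*k,) output
# y_c and y go in as non_sequences (slices passed via sequences are fresh variables, disconnected from each other)
grads, _ = pytensor.scan(lambda i, f, x: pytensor.grad(f[i], x).ravel()[i],  # d y_c[i] / d y[i]
                         sequences=[idx], non_sequences=[y_c.ravel(), y])
jac_diag_terms = grads.reshape((-1, k))
logdets = pt.log(pt.abs(jac_diag_terms)).sum(axis=-1)  # log|det| of each row's diagonal Jacobian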

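Here I’m exploiting the fact that everything is elemwise, so every off-diagonal entry of the (nk, nk) Jacobian of the raveled arrays is zero and only its diagonal needs to be computed. If you reshape that diagonal to (n, k), each row is the main diagonal of one of the (n, k, k) Jacobians you get from the other procedure I pitched. The determinant of a diagonal matrix is just the product of its diagonal entries, so you can compute the log-determinant “by hand” as the sum of the logs of (the absolute values of) each row.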
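As a sanity check of the argument above, here is a NumPy-only sketch (not PyTensor) that builds the full (nk, nk) Jacobian of an elementwise transform by finite differences and compares the “diagonal, reshape, sum of logs” shortcut against `np.linalg.slogdet` on each (k, k) block. The transform `f` (softplus), the helper `full_jacobian`, and the sizes `n, k = 4, 3` are hypothetical stand-ins for `y -> y_c` in the model, chosen only for illustration.

```python
import numpy as np


def f(x):
    # Hypothetical elementwise transform standing in for y -> y_c (softplus)
    return np.log1p(np.exp(x))


def full_jacobian(fun, x, eps=1e-6):
    # Central finite differences on the raveled input: J[a, b] = d fun(x)[a] / d x[b]
    x = x.ravel()
    jac = np.empty((x.size, x.size))
    for b in range(x.size):
        step = np.zeros_like(x)
        step[b] = eps
        jac[:, b] = (fun(x + step) - fun(x - step)) / (2 * eps)
    return jac


rng = np.random.default_rng(0)
n, k = 4, 3
y = rng.normal(size=(n, k))

J = full_jacobian(f, y)  # shape (n*k, n*k)

# Elemwise transform: every off-diagonal entry is exactly zero
off_diag = J - np.diag(np.diag(J))

# Gradient of the sum = column sums of J, which is just the diagonal here
grad_of_sum = J.sum(axis=0)

# Shortcut: keep the diagonal, reshape to (n, k), sum log|.| per row
jac_diag_terms = np.diag(J).reshape(n, k)
logdets_fast = np.log(np.abs(jac_diag_terms)).sum(axis=-1)

# Reference: log|det| of each (k, k) block on the diagonal of J
blocks = np.stack([J[i * k:(i + 1) * k, i * k:(i + 1) * k] for i in range(n)])
_, logdets_ref = np.linalg.slogdet(blocks)

print(np.allclose(logdets_fast, logdets_ref))  # True
```

Finite differences are only there to produce a reference Jacobian for the check; in the model, PyTensor's autodiff does that job. Because `grad_of_sum` matches the diagonal, `pytensor.grad(y_c.sum(), y)` should, under the same elemwise assumption, return those diagonal terms in a single backward pass with no Scan at all.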