What if you do it “by hand”, i.e.:
```python
# Iterate over indices rather than raveled elements: slices of a scan
# sequence are disconnected from the outer graph, so grad(expr, x) over
# element pairs would raise a DisconnectedInputError. Passing y_c and y
# as non_sequences (as in the standard scan Jacobian pattern) keeps the
# connection visible to grad.
grads, _ = pytensor.scan(
    lambda i, y_c_flat, y: pytensor.grad(y_c_flat[i], y).ravel()[i],
    sequences=pt.arange(y_c.size),
    non_sequences=[y_c.ravel(), y],
)
jac_diag_terms = grads.reshape((-1, k))
logdets = pt.log(jac_diag_terms).sum(axis=-1)
```
Here I’m exploiting the fact that everything is elemwise, and only computing the diagonal of the (nk, nk) raveled Jacobian. If you reshape that to (n, k), each row will be the main diagonal of one of the (n, k, k) Jacobians you get from the other procedure I pitched, and since each of those is diagonal, you can compute the determinant “by hand” as the product of the diagonal entries (so the log-determinant is just the sum of their logs).