A small update: I replaced the call to scan with two explicit calls:
res1 = kron_vector_op(m)
res = kron_vector_op(res1)
and indeed it speeds things up a lot. Would changing from scan to an actual Python for loop help in general? Or is it not possible to know ahead of time how many times kron_vector_op is going to be called?
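To make the question concrete, here is a minimal sketch of what I mean by "an actual for loop". It assumes the number of applications is a plain Python int known when the graph is built. `apply_repeatedly` and `toy_op` are hypothetical names; `toy_op` is only a stand-in for kron_vector_op so the snippet runs on plain lists.

```python
def apply_repeatedly(op, x, n_steps):
    """Apply `op` to `x` n_steps times using a plain Python loop.

    With symbolic tensors, each iteration would add one more application
    of `op` to the graph (the loop is unrolled), so no scan is involved.
    This only works if n_steps is a Python int, not a symbolic value.
    """
    for _ in range(n_steps):
        x = op(x)
    return x


# Hypothetical stand-in for kron_vector_op (assumption, not the real op).
def toy_op(v):
    return [2 * a for a in v]


m = [1, 2, 3]
res = apply_repeatedly(toy_op, m, 2)  # equivalent to toy_op(toy_op(m))
```

With `n_steps=2` this is the same as the two explicit calls above. The trade-off, as far as I understand it, is that an unrolled graph grows with the number of steps, so it may only pay off when that number is small.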
Without scan, about 80% of the time to calculate the log likelihood goes to Eigh, and about 95% of the time to calculate the gradients goes to EighGrad, so I guess there isn’t much that can be done to speed it up further.