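Did you try the solution in this Frequently Asked Questions entry? The likely cause is that `p = tt.where(p <= 0, 1e-50, p)` in your implementation breaks the gradient. Wherever p <= 0, the output no longer depends on p, so the gradient there is zero, and the value jumps abruptly at p = 0.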
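To see why this clamp is a problem for a gradient-based sampler, here is a minimal sketch in plain Python, not Theano. It assumes the clamped p feeds into a log-likelihood term like log(p). The helper names `clamped_logp` and `finite_diff` are hypothetical and exist only for illustration. A finite-difference estimate stands in for the symbolic gradient.

```python
import math

def clamped_logp(p):
    # Hypothetical stand-in for log(tt.where(p <= 0, 1e-50, p)).
    q = 1e-50 if p <= 0 else p
    return math.log(q)

def finite_diff(f, x, h=1e-6):
    # Central finite-difference estimate of df/dx.
    return (f(x + h) - f(x - h)) / (2 * h)

for p in (-0.5, -0.1, 0.1, 0.5):
    # For negative p the estimated gradient is 0: the sampler gets no signal there.
    print(p, clamped_logp(p), finite_diff(clamped_logp, p))

# Crossing p = 0 makes the log-density jump by roughly 100 nats.
print(clamped_logp(1e-6) - clamped_logp(-1e-6))
```

The output shows a flat region for p <= 0 followed by a cliff at p = 0. That combination tends to confuse HMC/NUTS. Keeping p positive by construction, rather than clipping it afterwards, usually avoids the problem.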