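Oh, I think I see. Using `np.max` directly reduces too many dimensions: with no `axis` argument it reduces over every axis and collapses the whole array to a single scalar. Try `t_stat=lambda x: np.max(x, axis=-1)`, which reduces only the last axis and so returns one maximum per row.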
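A small sketch of the difference. It assumes the statistic is handed a 2-D array whose last axis holds the observations (the array `x` below is made up for illustration; the original question's data isn't shown):

```python
import numpy as np

# Hypothetical batch: 3 samples (rows), each with 4 observations (last axis).
x = np.array([[1, 5, 2, 0],
              [7, 3, 3, 1],
              [4, 4, 9, 2]])

# No axis: np.max reduces over every dimension and returns one scalar.
overall = np.max(x)

# axis=-1: only the last axis is reduced, giving one maximum per row.
t_stat = lambda x: np.max(x, axis=-1)
per_row = t_stat(x)

print(overall)   # 9
print(per_row)   # [5 7 9]
```

Using `axis=-1` rather than a fixed `axis=1` keeps the statistic working if extra leading (batch) dimensions are added, since it always reduces over the observations on the last axis.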