Custom functions in JAX?

I think I read somewhere that you can define your own functions in JAX, and maybe in NumPy too. Is that right?
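Here is a minimal sketch. It assumes "custom functions" means one of two things: ordinary user-defined Python functions that JAX can transform, or functions with a hand-written derivative rule. It uses `jax.jit`, `jax.grad`, and `jax.custom_jvp` from JAX and `np.vectorize` from NumPy. The names `softplus`, `log1pexp`, and `clip_to_unit` are hypothetical examples and are not part of either library.

```python
import jax
import jax.numpy as jnp
import numpy as np

# 1) A plain user-defined function built from jax.numpy operations.
def softplus(x):
    return jnp.log1p(jnp.exp(x))

# JAX transformations accept ordinary Python functions like this one.
fast_softplus = jax.jit(softplus)  # compiled version
d_softplus = jax.grad(softplus)    # derivative, mathematically sigmoid(x)

# 2) A function with a custom derivative rule via jax.custom_jvp.
@jax.custom_jvp
def log1pexp(x):
    return jnp.log1p(jnp.exp(x))

@log1pexp.defjvp
def log1pexp_jvp(primals, tangents):
    (x,) = primals
    (x_dot,) = tangents
    y = log1pexp(x)
    # Hand-written rule: d/dx log(1 + e^x) = sigmoid(x), which stays
    # finite for large x where exp(x) overflows.
    return y, jax.nn.sigmoid(x) * x_dot

# 3) NumPy: wrap a scalar Python function so it broadcasts over arrays.
def clip_to_unit(v):
    return 0.0 if v < 0 else (1.0 if v > 1 else v)

np_clip = np.vectorize(clip_to_unit, otypes=[float])

print(fast_softplus(0.0))             # ≈ 0.693 (log 2)
print(d_softplus(0.0))                # 0.5
print(jax.grad(log1pexp)(100.0))      # 1.0, from the custom rule
print(np_clip(np.array([-1.0, 0.5, 2.0])))  # [0.  0.5 1. ]
```

You only need `jax.custom_jvp` when you want to override how JAX differentiates a function. For everyday use, any Python function built from `jax.numpy` operations already works with `jit` and `grad`. `np.vectorize` is a convenience wrapper that essentially runs a Python loop, so it will not speed anything up.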