Native JAX support in PyMC without wrappers?

There are plans, but no built-in way yet. See the PyTensor issue "Implement helper `@as_jax_op` to wrap JAX functions in PyTensor" (pymc-devs/pytensor, Issue #537).

For JAX-based sampling you only need the Op and its JAX dispatch function. You don't need to implement gradients or a `perform` method. The last example in the blog post is about as succinct as it currently gets.