Native JAX support in PyMC without wrappers?

What sampler are you planning to use? The PyMC sampler doesn’t have any options like that.

If you want to build the model in PyMC but sample with NumPyro, you can pass any optional keyword arguments through to it via `nuts_sampler_kwargs`. The gradients, whether forward or backward mode, are computed by JAX anyway, so there's nothing else you need to do on the PyTensor side.