Native JAX support in PyMC without wrappers?

I’ve written a model in JAX that solves ODEs (using the diffrax package), and I want to infer the ODE parameters. I’m deciding between NumPyro and PyMC. I prefer PyMC (I’ve used it before), but using it with JAX seems complicated and requires several wrappers, for example: How to use JAX ODEs and Neural Networks in PyMC - PyMC Labs

NumPyro seems much easier because it supports JAX functions natively, for example: Example: Predator-Prey Model — NumPyro documentation

Are there plans (or existing functionality) to make it just as easy to use JAX in PyMC?

As a backup, I’m considering using NumPyro and then reading and analyzing the output with ArviZ.

There are plans, but no built-in way yet: Implement helper `@as_jax_op` to wrap JAX functions in PyTensor · Issue #537 · pymc-devs/pytensor · GitHub

For JAX-based sampling you only need the Op and its JAX dispatch function; no grad or perform method is required. The last example in the blog post is as succinct as it currently gets.

Thanks @ricardoV94. What if, for the diffrax example in the blog post, I wanted to use forward-mode (jvp) instead of reverse-mode (vjp) differentiation? Can PyMC use forward-mode gradients?

(numpyro has an option to use forward mode: Markov Chain Monte Carlo (MCMC) — NumPyro documentation)
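On the JAX side, forward and reverse mode compute the same gradient and differ only in cost. Reverse mode (`jax.grad`, built on vjp) needs one backward pass for all parameters, while forward mode (`jax.jacfwd`, built on jvp) needs one pass per parameter. Forward mode can therefore be competitive when an ODE model has only a handful of parameters, and it is the only option when the computation contains a `jax.lax.while_loop`, which JAX can differentiate in forward mode but not in reverse mode. The sketch below uses a hypothetical Euler-integrated `trajectory` as a stand-in for a diffrax solve and checks that both modes agree. NumPyro's `forward_mode_differentiation=True` option on `NUTS` is what switches its gradient computation to the forward-mode variant.

```python
import jax
import jax.numpy as jnp


# Hypothetical ODE stand-in (explicit Euler for dy/dt = -k * y), written in plain JAX
def trajectory(theta, n_steps=10, dt=0.1):
    y0, k = theta

    def step(y, _):
        y_next = y - dt * k * y
        return y_next, y_next

    _, ys = jax.lax.scan(step, y0, None, length=n_steps)
    return jnp.concatenate([jnp.atleast_1d(y0), ys])


y_obs = 0.8 ** jnp.arange(11.0)


def log_lik(theta):
    # Gaussian log-likelihood (up to a constant) with unit noise
    resid = trajectory(theta) - y_obs
    return -0.5 * jnp.sum(resid ** 2)


theta = jnp.array([1.2, 1.5])  # (y0, k)

# Reverse mode (vjp under the hood): one backward pass gives the full gradient
grad_rev = jax.grad(log_lik)(theta)

# Forward mode (jvp under the hood): one pass per parameter, here two
grad_fwd = jax.jacfwd(log_lik)(theta)

# A single jvp gives the directional derivative along the k axis
_, dk = jax.jvp(log_lik, (theta,), (jnp.array([0.0, 1.0]),))

print(grad_rev, grad_fwd, dk)
```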

What sampler are you planning to use? The PyMC sampler doesn’t have any options like that.

If you want to build the model in PyMC but sample with NumPyro (`pm.sample(nuts_sampler="numpyro")`), you can pass any optional NumPyro arguments through `nuts_sampler_kwargs`. The gradients, whether forward- or reverse-mode, are computed by JAX anyway, so there’s nothing else to do on the PyTensor side.