I understood your dissatisfaction to be about the input: you can't write a PyMC model line by line the way you would write a Stan model. Not about the output.
PyMC is perfectly capable of giving you the model's conditional/joint densities as JAX code (is that what you meant by "translated into JAX"?), which you can then transform/vmap/pmap as you please. That's how we interop with the JAX samplers (or with Rust via the Numba backend).