In the JAX guide you have examples for non-scalar outputs in the grad of the Ops. There's nothing special about non-scalar outputs; you just have to make sure you define the gradients correctly.
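Something along these lines works (a minimal sketch, not the guide's code: `jax_logp_elemwise`, `LogpOp`, and `LogpVJPOp` are made-up names, and the logp is just a toy element-wise function):

```python
# Toy sketch: a PyTensor Op wrapping a JAX function with a vector (non-scalar)
# output, plus a second Op that provides its gradient via jax.vjp.
import jax
import jax.numpy as jnp
import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.graph import Apply, Op


def jax_logp_elemwise(x):
    # Made-up JAX function with a non-scalar (vector) output.
    return -0.5 * jnp.square(x)


class LogpOp(Op):
    def make_node(self, x):
        x = pt.as_tensor_variable(x)
        # Vector in, vector out: the output type mirrors the input type.
        return Apply(self, [x], [x.type()])

    def perform(self, node, inputs, outputs):
        (x,) = inputs
        outputs[0][0] = np.asarray(jax_logp_elemwise(x), dtype=node.outputs[0].dtype)

    def grad(self, inputs, output_gradients):
        (x,) = inputs
        (g,) = output_gradients
        # For a non-scalar output, `g` is the cotangent of the whole output,
        # and grad must return the vector-Jacobian product wrt each input.
        return [logp_vjp_op(x, g)]


class LogpVJPOp(Op):
    def make_node(self, x, g):
        x = pt.as_tensor_variable(x)
        g = pt.as_tensor_variable(g)
        return Apply(self, [x, g], [x.type()])

    def perform(self, node, inputs, outputs):
        x, g = inputs
        y, vjp_fn = jax.vjp(jax_logp_elemwise, x)
        (gx,) = vjp_fn(jnp.asarray(g, dtype=y.dtype))
        outputs[0][0] = np.asarray(gx, dtype=node.outputs[0].dtype)


logp_op = LogpOp()
logp_vjp_op = LogpVJPOp()

# The gradient of the summed output wrt the vector input then works as usual.
x = pt.vector("x")
dlogp = pytensor.grad(logp_op(x).sum(), x)
print(pytensor.function([x], dlogp)(np.array([1.0, 2.0])))  # [-1. -2.]
```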
Regardless, check my comment above. I think you are confusing what we mean by Potential summing the output. The independent logps (regardless of how you bring them into the model) will always be summed by PyMC, and that's not a problem. It isn't "confusing" or "mixing" them, since the graph is symbolic and the gradients will be properly partitioned with respect to each input even after summing.
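To show what I mean by "properly partitioned" (a toy sketch outside of PyMC, with made-up logp terms):

```python
import numpy as np
import pytensor
import pytensor.tensor as pt

# Two independent "logp" terms that get summed into one scalar,
# the way PyMC combines all model terms into the total logp.
a = pt.vector("a")
b = pt.vector("b")
logp_a = -0.5 * a**2             # stands in for one variable's logp terms
logp_b = -0.5 * (b - 1.0) ** 2   # stands in for another variable's logp terms

total_logp = logp_a.sum() + logp_b.sum()

# The gradient is taken symbolically wrt each input of the summed graph,
# so each input only "sees" the terms that actually depend on it.
grad_a, grad_b = pytensor.grad(total_logp, [a, b])
f = pytensor.function([a, b], [grad_a, grad_b])

ga, gb = f(np.array([1.0, 2.0]), np.array([0.0, 3.0]))
print(ga)  # [-1. -2.]  -> d/da of -0.5*a**2
print(gb)  # [ 1. -2.]  -> d/db of -0.5*(b-1)**2
```

The same thing happens with the model logp PyMC compiles: summing the terms doesn't destroy the per-input structure of the gradient.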