Every time you call posterior predictive a lot of things happen, including compiling the forward function and converting your posterior trace into something easier to work with.
If you are calling posterior predictive many times, you should reuse both of these instead of rebuilding them on every call.
If you look at the source code of posterior predictive, there are two variables you want to intercept: the parsed trace and the compiled forward function (you can just copy-paste the code and make it return them to you).
Then you can set your shared data and feed the trace into the function whenever you want. You can also do dynamic thinning at this point by skipping some points in the trace.
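Roughly, the pattern looks like the toy sketch below. It is plain Python, not PyMC code: `parsed_trace`, `make_forward_fn`, `shared` and `predict` are made-up stand-ins for the parsed trace, the compiled forward function and the shared data that you would pull out of the posterior predictive source. The point is only to show the expensive setup happening once and the cheap loop being reused.

```python
import random

# Hypothetical stand-in for the parsed trace: one dict of parameter values per draw.
parsed_trace = [{"intercept": 1.0 + 0.01 * i, "slope": 2.0} for i in range(1000)]

# Hypothetical stand-in for the model's shared data container.
shared = {"x": [0.0, 1.0]}


def make_forward_fn(shared_data):
    """Stand-in for compiling the forward function (the expensive step)."""
    rng = random.Random(0)

    def forward_fn(intercept, slope):
        # Reads the *current* shared data every time it is called.
        return [rng.gauss(intercept + slope * x, 0.1) for x in shared_data["x"]]

    return forward_fn


# Build the forward function once, outside of any loop.
forward_fn = make_forward_fn(shared)


def predict(new_x, thin=1):
    """Swap in new data, then push (a thinned slice of) the trace through."""
    shared["x"] = new_x
    return [forward_fn(**point) for point in parsed_trace[::thin]]


preds_a = predict([0.0, 1.0, 2.0], thin=10)   # 100 draws, 3 predictions each
preds_b = predict([5.0], thin=100)            # 10 draws, 1 prediction each
```

In the real thing the two stand-ins would be whatever the posterior predictive source builds internally, so the exact names and call signature depend on your PyMC version.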
Another thing you can try is to compile your forward function to JAX by passing compile_kwargs=dict(mode="JAX"), which could be faster.
There are other solutions from this point on, but they become more complex so I would check if this gets you there.