sample_posterior_predictive(predictions=True) takes a long time to "load": can JAX help?

Hello! I’ve noticed that when I try to do inference on out-of-sample data, CPU usage goes to 100% and sample_posterior_predictive(predictions=True) takes a long time to “load”. The model is quite big. I define it before asking it to sample the prior.

It samples fine (and quickly, <1 s), but the initial load takes a good 30+ seconds, during which no progress bar is shown. It’s quicker on my laptop than on a cloud-hosted server, but I see the same behavior on both.

I thought this could be caused by lazy loading in az.from_netcdf(), since I am loading a saved trace. But switching it to ‘eager’ loading doesn’t seem to help.

I also dug around in the source to see whether any of the other samplers had a method with a predictions argument, but I couldn’t find anything.

Is there any other way to run posterior predictive sampling on out-of-sample data? Do I just resample with JAX after I load my trace and set my data containers? Wouldn’t that throw away everything I’ve done with pm.sample()?

Ideally I’d like to do inference on out-of-sample data as fast as possible. Can JAX help?
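One way JAX could help, sketched here under loud assumptions: this is not a PyMC API. It is a hand-written re-implementation of the same hypothetical toy likelihood, y ~ Normal(alpha + beta * x, sigma), vectorised over posterior draws with jax.vmap and compiled once with jax.jit. The arrays below are stand-ins for draws that would in practice be pulled from idata.posterior.

```python
import jax
import jax.numpy as jnp


def predict_one(key, alpha, beta, sigma, x_new):
    # Forward-simulate y ~ Normal(alpha + beta * x, sigma) for ONE posterior draw
    mu = alpha + beta * x_new
    return mu + sigma * jax.random.normal(key, x_new.shape)


@jax.jit
def posterior_predictive(key, alpha, beta, sigma, x_new):
    # One PRNG key per posterior draw, then vectorise over draws with vmap;
    # x_new is shared (in_axes=None) across all draws
    keys = jax.random.split(key, alpha.shape[0])
    return jax.vmap(predict_one, in_axes=(0, 0, 0, 0, None))(
        keys, alpha, beta, sigma, x_new
    )


# Stand-in posterior draws (hypothetical; in practice e.g.
# jnp.asarray(idata.posterior["alpha"].values.ravel()))
n_draws = 1000
alpha = jnp.zeros(n_draws)
beta = jnp.ones(n_draws)
sigma = jnp.full(n_draws, 0.1)
x_new = jnp.linspace(0.0, 1.0, 5)

y_pred = posterior_predictive(jax.random.PRNGKey(0), alpha, beta, sigma, x_new)
print(y_pred.shape)  # (1000, 5)
```

The first call pays the compilation cost; later calls with arrays of the same shapes reuse the compiled function, while a new x_new length triggers a recompile. The trade-off is that the model's generative code has to be duplicated by hand.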

Cheers.