I have a use case similar to the post linked below: fitting a PyMC model to many datasets in parallel.
My implementation for fitting in parallel is pretty much the same as in the linked post, but for sampling I'm trying to use pm.sampling_jax.sample_numpyro_nuts, as I've found it to be nearly 10x faster than pm.sample. I've configured JAX to use the CPU and set chain_method='sequential'.
This seems to compile fine, but it hangs indefinitely at the sampling stage, before the progress bar even appears:
Compiling...
Compilation time = 0:00:01.662194
Sampling...
I’d appreciate any advice on how to make this work, thanks!