sample_numpyro_nuts hangs when parallelizing over datasets with multiprocessing

I have a use case similar to the post linked below: fitting a PyMC model to many datasets in parallel.

My parallel-fitting implementation is essentially the same as in the linked post, but for sampling I'm using pm.sampling_jax.sample_numpyro_nuts, which I've found to be nearly 10x faster than pm.sample. I've configured JAX to use the CPU and set chain_method='sequential'.

It seems to compile fine, but then hangs indefinitely at the sampling stage, before the progress bar appears:

Compilation time =  0:00:01.662194
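For reference, here is a minimal sketch of this kind of setup. The model and the names fit_one, fit_all and mp_context are hypothetical stand-ins, not the original code, and it assumes PyMC 4.x, where the JAX samplers live in pymc.sampling_jax as used above. One assumption worth flagging: JAX is multithreaded, and JAX itself warns that calling os.fork() (the default multiprocessing start method on Linux) after it has started is likely to deadlock. That is a plausible cause of a hang like this one, though not confirmed here, so the sketch uses the "spawn" start method and imports JAX only inside each worker.

```python
import multiprocessing as mp


def fit_one(data):
    """Fit a simple normal model to one dataset (hypothetical model).

    pymc and jax are imported inside the worker so that each spawned
    process initializes JAX itself instead of inheriting JAX's threads
    from a forked parent.
    """
    import jax

    # Force JAX onto the CPU in this worker, before any JAX computation runs.
    jax.config.update("jax_platform_name", "cpu")

    import pymc as pm
    import pymc.sampling_jax  # assumption: PyMC 4.x module layout

    with pm.Model():
        mu = pm.Normal("mu", mu=0.0, sigma=10.0)
        sigma = pm.HalfNormal("sigma", sigma=5.0)
        pm.Normal("obs", mu=mu, sigma=sigma, observed=data)
        idata = pymc.sampling_jax.sample_numpyro_nuts(
            chains=1,
            chain_method="sequential",
            progress_bar=False,
        )
    # Return a small picklable summary rather than the full InferenceData.
    return float(idata.posterior["mu"].mean())


def mp_context():
    # "spawn" starts each worker in a fresh interpreter, whereas "fork"
    # copies the parent process, which can deadlock with multithreaded JAX.
    return mp.get_context("spawn")


def fit_all(datasets, processes=4):
    """Fit every dataset in its own worker process and collect the summaries."""
    with mp_context().Pool(processes=processes) as pool:
        return pool.map(fit_one, datasets)


if __name__ == "__main__":
    import numpy as np

    rng = np.random.default_rng(0)
    datasets = [rng.normal(loc=i, scale=1.0, size=100) for i in range(8)]
    print(fit_all(datasets))
```

With "spawn", the worker function has to be importable from a module (run as a script, not defined in a notebook cell), and the pool must be created under the `if __name__ == "__main__":` guard, otherwise each child re-executes the pool setup while bootstrapping.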

I’d appreciate any advice on how to make this work, thanks!