Specific model is slower in PyMC than in NumPyro with the same sampler

Hi all,

I have a model in NumPyro and the equivalent model in PyMC v4, both using the NumPyro sampler. The NumPyro version runs a chain in about 8 minutes for me; the PyMC one takes about 14. I had some help from the NumPyro developers in building that model, so I'm wondering if anyone can see ways to improve the model and make it quicker.



Without looking at it in detail, you can try passing check_bounds=False to pm.Model if you know the sampler will not propose invalid values.

Otherwise, many things might differ between the JAX graph that PyMC generates and the NumPyro one. You can try to inspect those.