I installed PyMC 5 this morning (which uses PyTensor as the backend). My machine is an M1 Max with 64 GB of memory. I decided to test the default sampler against the JAX-based NumPyro and BlackJAX samplers. Before PyMC 5, the default sampler was always much slower than JAX.
I don’t know what dark magic the PyMC devs are doing (maybe the port to Numba? or is this downstream from Aesara?), but the default sampler was faster than both JAX-based samplers (PyTensor = 3 min 20 s; NumPyro = 3 min 58 s; BlackJAX = 5 min 15 s).
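To put the three timings on a common footing, here is a minimal sketch that converts them to seconds and expresses each one as a multiple of the default PyTensor run. The labels and the `to_seconds` helper are illustrative only and are not part of PyMC.

```python
# Wall-clock times reported for one model on an M1 Max with PyMC 5.
reported = {
    "PyTensor (default)": "3min20s",
    "NumPyro": "3min58s",
    "BlackJAX": "5min15s",
}


def to_seconds(label: str) -> int:
    """Parse a duration written like '3min20s' into whole seconds."""
    minutes, rest = label.split("min")
    return int(minutes) * 60 + int(rest.rstrip("s"))


seconds = {name: to_seconds(t) for name, t in reported.items()}
baseline = seconds["PyTensor (default)"]

for name, s in seconds.items():
    # A ratio above 1 means slower than the default PyTensor sampler.
    print(f"{name:20s} {s:4d} s  {s / baseline:.2f}x")
```

On these numbers, NumPyro takes about 1.19× and BlackJAX roughly 1.6× the default sampler's wall-clock time.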
I had to rerun it a couple of times to confirm because I was initially incredulous, but PyTensor was faster, at least for this model. I have attached an image of the sampling with the different backends.
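Rerunning to confirm a surprising result is easy to script. The following is a minimal sketch that assumes each sampler run is wrapped in a zero-argument callable. It times several repeats with `time.perf_counter` and reports the median, which is less sensitive to a single slow run than one measurement. The helper names `time_repeated` and `compare` are hypothetical. With PyMC, the callables might wrap `pm.sample(...)` with a different `nuts_sampler` setting (available in recent PyMC 5 releases), but the example uses stand-in workloads so it runs without PyMC. JAX compiles on first use, so it is worth deciding deliberately whether compile time belongs in the measurement.

```python
import statistics
import time
from typing import Callable, Dict, List


def time_repeated(fn: Callable[[], object], repeats: int = 3) -> List[float]:
    """Run fn `repeats` times and return each wall-clock duration in seconds."""
    durations = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        durations.append(time.perf_counter() - start)
    return durations


def compare(samplers: Dict[str, Callable[[], object]], repeats: int = 3) -> Dict[str, float]:
    """Median wall-clock time per named callable; the median damps one-off outliers."""
    return {name: statistics.median(time_repeated(fn, repeats)) for name, fn in samplers.items()}


# Stand-in workloads. With PyMC these would be sampler calls, e.g. (hypothetical wiring)
# lambda: pm.sample(model=model, nuts_sampler="numpyro").
timings = compare({
    "fast": lambda: sum(range(10_000)),
    "slow": lambda: sum(range(1_000_000)),
})
print(timings)
```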
I will test with other models, but if this holds, that would be very impressive.
This is most probably an effect of Python 3.11.
While we’re making small improvements all the time, I’m not aware of anything between v4 and v5 that would explain such a big jump.
I just ran one of my models with PyMC 4.2.2 under Python 3.10.8 vs. PyMC v5.0.0 under Python 3.9.15, and the v5 run was only about 5% faster.
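Since the question is whether the speedup comes from the Python version or from PyMC itself, it helps to record both alongside every timing. The following is a minimal sketch using only the standard library (`platform` and `importlib.metadata`). The `environment_report` name and the default package list are assumptions chosen for this thread, and packages that are not installed are simply reported as such.

```python
import platform
from importlib import metadata


def environment_report(packages=("pymc", "pytensor", "numpyro", "blackjax", "jax")):
    """Collect interpreter and package versions to store next to each timing."""
    report = {"python": platform.python_version(), "machine": platform.machine()}
    for pkg in packages:
        try:
            report[pkg] = metadata.version(pkg)
        except metadata.PackageNotFoundError:
            report[pkg] = "not installed"
    return report


print(environment_report())
```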