For CPU sampling of huge models that do not involve extremely high precision calculations (thinking mostly about matrix decompositions here), you can set `pytensor.config.floatX = 'float32'` for some "free" speedups as well (you might have to fuss with `.astype` calls throughout your model, though; I think @ricardoV94 was working on some helpers to make this easier).
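To illustrate the `.astype` fussing: numpy defaults to float64, so any float64 array that sneaks into the graph silently upcasts the result and forfeits the speedup. A minimal numpy-only sketch of the failure mode (not pytensor-specific):

```python
import numpy as np

rng = np.random.default_rng(0)
data64 = rng.normal(size=4)        # numpy defaults to float64
data32 = data64.astype("float32")  # cast inputs once, up front

stray64 = np.ones(4)               # a forgotten float64 array...

assert (data32 + data32).dtype == np.float32   # stays float32
assert (data32 + stray64).dtype == np.float64  # ...silently upcasts
```

This is why every observed array and constant in the model needs the cast, not just the free variables.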
Pedantic sidebar that nobody asked for: there's no PyMC numba sampler. When you set `compile_kwargs={'mode': 'NUMBA'}`, pytensor compiles a `numba.njit`-decorated logp function, which is then called inside the same old PyMC NUTS sampler, implemented in pure python. Nutpie in numba mode also calls the same jit-compiled logp function, but uses a NUTS sampler written in rust to run the sampling loop. Nutpie also has a smarter tuning strategy, so you can get away with fewer warmup steps, which translates into a further speedup.
Setting aside the actual logp and gradients of your model, NUTS isn't actually a computationally heavy algorithm. That's why we can get away with a boring old python implementation.
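To make that concrete, here's a toy leapfrog loop of the kind that sits at the heart of HMC/NUTS (a hedged sketch, not PyMC's actual implementation): the sampler's own work per step is a few vector additions, and essentially all the cost lives in the `grad_logp` calls.

```python
import numpy as np

def grad_logp(q):
    # stand-in for the (expensive) compiled model gradient;
    # here the target is a standard normal, so d/dq log p(q) = -q
    return -q

def leapfrog(q, p, step_size=0.1, n_steps=10):
    """One simulated trajectory. The loop itself is trivially cheap
    python; the real cost is the repeated grad_logp calls."""
    p = p + 0.5 * step_size * grad_logp(q)    # initial half-step for momentum
    for _ in range(n_steps - 1):
        q = q + step_size * p                 # full-step position update
        p = p + step_size * grad_logp(q)      # full-step momentum update
    q = q + step_size * p
    p = p + 0.5 * step_size * grad_logp(q)    # final half-step for momentum
    return q, p

q1, p1 = leapfrog(np.zeros(3), np.ones(3))
# leapfrog approximately conserves the Hamiltonian H = -logp(q) + 0.5*p@p
H1 = 0.5 * np.sum(q1**2) + 0.5 * np.sum(p1**2)
```

For the standard normal target the starting energy is 1.5, and `H1` stays within a small discretization error of that, which is the property the integrator exists to provide.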
Regarding MLX, the package is very deep-learning focused and does not have many of the specialized statistics functions we need (like `gammainc`, used by the `HurdleGamma`!). I'm also very excited by the possibility of an MLX backend, but it's going to be a heavy lift to get 100% coverage. Help wanted!
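For context on what's missing: the Gamma CDF (and hence hurdle/censored Gamma logps) needs the regularized lower incomplete gamma function, which scipy exposes as `scipy.special.gammainc`. A quick sanity check of what it computes, using the fact that a Gamma with shape `a = 1` is just an Exponential:

```python
import numpy as np
from scipy.special import gammainc  # regularized lower incomplete gamma P(a, x)

# For shape a=1, Gamma(1, 1) is Exponential(1), so the CDF
# gammainc(1, x) must reduce to 1 - exp(-x):
x = np.array([0.5, 1.0, 2.0])
assert np.allclose(gammainc(1.0, x), 1.0 - np.exp(-x))
```

An MLX backend would need a native implementation of functions like this, which is a big part of the heavy lift.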