Hi,
I have installed pymc on a new Linux machine (a server I don’t manage), and sampling is much slower than I expected, and slower than what I get on other machines.
Is there any way to check whether pymc is using the right optimised libraries and, if so, which ones should I be looking for?
For example, I get WARNING (pytensor.tensor.blas): Using NumPy C-API based implementation for BLAS functions,
but this doesn’t seem to be much of an issue on other machines.
Thanks!
It’s difficult to answer in the abstract. In general, the recommendation is to follow the official installation instructions. If you have done so and are still getting the BLAS warnings (which are always a bad sign), then you can reply with the output of conda list and conda config --show-sources.
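Before posting, you can also scan the conda list output yourself for the BLAS-related packages. This is a minimal sketch that assumes the usual whitespace-separated columns (name, version, build, channel); the helper name is hypothetical, and the package names it looks for are only the ones mentioned in this thread, not an exhaustive list.

```python
# BLAS-related package names mentioned in this thread (not exhaustive).
BLAS_PACKAGES = {"libblas", "libopenblas", "mkl", "mkl-service"}


def blas_packages_from_conda_list(text):
    """Return {name: (version, build)} for BLAS-related lines of `conda list` output.

    Assumes whitespace-separated columns: name, version, build[, channel].
    Header/comment lines starting with '#' are skipped.
    """
    found = {}
    for line in text.splitlines():
        line = line.strip()
        if not line or line.startswith("#"):
            continue
        fields = line.split()
        if len(fields) >= 3 and fields[0] in BLAS_PACKAGES:
            found[fields[0]] = (fields[1], fields[2])
    return found
```

To use it on a real environment, pass it the captured output, e.g. subprocess.run(["conda", "list"], capture_output=True, text=True).stdout. The build column is the interesting part: it is what specs like libblas=*=*mkl* constrain.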
My understanding is that if you have an Intel processor, you need to make sure you have the following packages:
mkl
mkl-service
libblas=*=*mkl*
If it’s an ARM64 Mac, you need libblas=*=*accelerate* (and none of the mkl-related support packages). For AMD processors, I have no idea. The list of all supported conda-forge BLAS installations is here.
…but always follow the official instructions, because if you just use conda in the first place it handles this all for you automatically.
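The rule of thumb above can be written down as a small lookup. This is a minimal sketch with hypothetical helper names: it encodes only the advice in this post (Intel: MKL; ARM64 Mac: Accelerate) and returns an empty list for anything else, such as AMD, meaning "no specific pin, let conda choose its default". Since Intel and AMD both report x86_64, the sketch reads the vendor from Linux's /proc/cpuinfo text rather than guessing it.

```python
def cpu_vendor_from_cpuinfo(text):
    """Return the vendor_id from /proc/cpuinfo text, e.g. 'GenuineIntel' or 'AuthenticAMD'."""
    for line in text.splitlines():
        key, sep, value = line.partition(":")
        if sep and key.strip() == "vendor_id":
            return value.strip()
    return None


def suggested_blas_specs(system, machine, cpu_vendor=None):
    """Conda specs suggested in this thread for a platform.

    system/machine are as returned by platform.system() and platform.machine().
    An empty list means the thread gives no specific pin (e.g. AMD).
    """
    if system == "Darwin" and machine == "arm64":
        # Apple Silicon: Accelerate, and none of the mkl packages.
        return ["libblas=*=*accelerate*"]
    if cpu_vendor == "GenuineIntel":
        return ["mkl", "mkl-service", "libblas=*=*mkl*"]
    return []
```

On Linux you might call suggested_blas_specs(platform.system(), platform.machine(), cpu_vendor_from_cpuinfo(open("/proc/cpuinfo").read())). As noted above, using conda with the official instructions should normally take care of this for you.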
Thank you both.
I installed everything following the official instructions. I only now realised this machine is AMD rather than Intel, so no MKL.
It should be using OpenBLAS (libopenblas is installed with pymc), so maybe that explains the difference in speed.
Maybe @maresb has more insight into what should be expected on AMD?
This can make a huge difference for certain kinds of models (especially those with large matmuls). That’s why we check for it and emit that warning by default.
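To see what that check found on a given machine, you can inspect the linker flags PyTensor settled on. This is a sketch assuming a recent PyTensor, where the setting is exposed as pytensor.config.blas__ldflags; an empty value is what leads to the "Using NumPy C-API based implementation" fallback. The classification helper is hypothetical and only looks for substrings, so treat its labels as a rough guide.

```python
def describe_blas_ldflags(ldflags):
    """Roughly classify a PyTensor blas__ldflags string (hypothetical helper)."""
    flags = ldflags.lower()
    if not flags.strip():
        # No BLAS found: PyTensor falls back to the NumPy C-API implementation.
        return "none (NumPy C-API fallback)"
    if "mkl" in flags:
        return "MKL"
    if "openblas" in flags:
        return "OpenBLAS"
    if "accelerate" in flags:
        return "Accelerate"
    return "other: " + ldflags.strip()


def current_blas_ldflags():
    """Return pytensor.config.blas__ldflags, or None if PyTensor is not importable."""
    try:
        import pytensor
    except ImportError:
        return None
    return pytensor.config.blas__ldflags
```

In the environment you sample from, describe_blas_ldflags(current_blas_ldflags() or "") gives a quick summary; an empty result on the slow server would match the warning in the original post.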
But it also isn’t relevant if you use a JAX/Numba sampler, right?
Right, sorry if I missed that that was the use case.
It wasn’t mentioned in the OP, but if getting BLAS set up correctly is totally off the table (due to permissions or whatever), maybe just using an alternative sampler would be a good workaround.
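The workaround amounts to a simple decision. This hypothetical helper encodes it as a value for PyMC’s pm.sample(nuts_sampler=...) argument: keep the default "pymc" sampler when BLAS is set up (or when the model can’t be run through JAX), and otherwise switch to "numpyro", which, as discussed above, doesn’t depend on PyTensor’s C BLAS linkage. How you determine the two flags is up to you.

```python
def choose_nuts_sampler(blas_configured, model_is_jax_compatible):
    """Hypothetical helper: pick a value for pm.sample(nuts_sampler=...).

    "pymc" is PyMC's default sampler and benefits from a properly linked BLAS.
    "numpyro" runs the model through JAX, so the BLAS warning doesn't apply.
    """
    if blas_configured or not model_is_jax_compatible:
        return "pymc"
    return "numpyro"
```

For example, inside a model context: idata = pm.sample(nuts_sampler=choose_nuts_sampler(False, True)).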
I’m actually using the numpyro sampler for some of the models I’m testing, but some don’t have JAX wrappers, so for those I need to use the PyMC sampler.
Thank you all for the help!
On my Apple Silicon machine I don’t get libblas=*=*accelerate*, but I do get:
libblas 3.9.0 22_osxarm64_openblas conda-forge
I assume that’s kosher?
Yeah, I believe this is the Accelerate BLAS (it’s what I get on my Mac as well).
Edit: They’re not the same, but they should be comparable; see here.
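A spec like libblas=*=*accelerate* only constrains the build string, so you can tell which variant you have by matching the build column against the same glob. This sketch uses Python’s fnmatch as an approximation of conda’s own matching, and the helper names are hypothetical. The build string 22_osxarm64_openblas matches *openblas* but not *accelerate*, which is consistent with the edit that the two are not the same.

```python
from fnmatch import fnmatchcase


def build_matches(build, pattern):
    """Check a conda build string against the build part of a spec like libblas=*=*mkl*.

    Shell-style globbing, used here as an approximation of conda's matching.
    """
    return fnmatchcase(build, pattern)


def blas_flavour(build):
    """Return the first BLAS variant from this thread whose glob matches, else None."""
    for flavour in ("mkl", "accelerate", "openblas"):
        if build_matches(build, f"*{flavour}*"):
            return flavour
    return None


print(build_matches("22_osxarm64_openblas", "*accelerate*"))  # False
print(blas_flavour("22_osxarm64_openblas"))                    # openblas
```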