Slow sampling with pm.sampling_jax.sample_numpyro_nuts() on Vertex AI

Hi, I am using an M2 MacBook Air with 16 GB of RAM. Sampling with the JAX backend on the CPU runs very smoothly. However, my model is rather complicated, with a large number of parameters, so it takes about 30 minutes to sample. Because of that, I tried switching to a Vertex AI Workbench notebook on GCP. The machine specs are: Compute Optimized, 16 vCPUs, 64 GB RAM.

Surprisingly, sampling on GCP turned out to be about 10 times slower. Drawing 1000 samples on 4 chains with parallel sampling takes more than 15 minutes there, while the same run finishes locally in under a minute.

Any idea why that is? I can't keep sampling these models locally, and at some point the computation will be moved to the cloud anyway.

Thank you all for your suggestions.
