Slow sampling with pm.sampling_jax.sample_numpyro_nuts() on Vertex AI

Hi, I am using an M2 MacBook Air with 16 GB of RAM. Sampling with the JAX backend on the CPU is very smooth. However, my model is rather complicated, with a large number of parameters, so it takes about 30 minutes to sample. Because of that, I tried switching to a Vertex AI Workbench notebook on GCP. The machine specs are: Compute Optimized, 16 vCPUs, 64 GB RAM.

To my surprise, sampling on GCP turned out to be about 10 times slower. Drawing 1000 samples on 4 chains with parallel sampling takes more than 15 minutes there. Locally, the same run finishes in under a minute.
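One configuration detail that may be worth ruling out. NumPyro's `chain_method="parallel"` spreads chains across JAX devices. On a CPU, JAX exposes a single device unless the XLA flag `--xla_force_host_platform_device_count` is set. When there are fewer devices than chains, NumPyro warns and draws the chains sequentially instead. PyMC's JAX sampling module is believed to set this flag itself when it is imported. However, the flag has no effect if JAX was already initialized earlier in the notebook kernel. The sketch below is a minimal example of setting the flag manually before JAX starts. The helper name `force_host_devices` is hypothetical; `numpyro.set_host_device_count(n)` does the same job.

```python
import os
import re

# XLA CPU flag that splits the host into `n` logical JAX devices.
_FLAG = "--xla_force_host_platform_device_count"


def force_host_devices(n: int) -> str:
    """Hypothetical helper: make JAX see `n` CPU devices (one per chain).

    XLA reads XLA_FLAGS when its backend is first initialized, so call this
    before importing jax (directly or through PyMC / NumPyro).
    """
    if n < 1:
        raise ValueError("n must be a positive integer")
    existing = os.environ.get("XLA_FLAGS", "")
    # Remove any earlier device-count setting so only one value remains.
    kept = re.sub(re.escape(_FLAG) + r"=\S+", "", existing).split()
    kept.append(f"{_FLAG}={n}")
    os.environ["XLA_FLAGS"] = " ".join(kept)
    return os.environ["XLA_FLAGS"]


if __name__ == "__main__":
    force_host_devices(4)  # chains=4 -> ask for 4 host devices
    try:
        import jax

        # If this prints 1, parallel chains would fall back to sequential.
        print("JAX devices:", jax.local_device_count())
    except ImportError:
        print("JAX not installed; XLA_FLAGS =", os.environ["XLA_FLAGS"])
```

Printing `jax.local_device_count()` in the Workbench kernel, right before sampling, is a cheap way to check this. Whether it accounts for the gap between the two machines is still an open question.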

Any idea why that is? I can't keep sampling the models locally, and at some point the computation will move to the cloud anyway.

Thank you all for your suggestions.


Bump.

I have the same issue.

It's a fairly simple model, but it uses nested multilevel modeling with ragged data.

32 vCPUs and 120 GB RAM
100k samples
Global → Second level → Third level → Output
One predictor only

It takes about 30-35 minutes with 1000 tuning steps, 1000 draws, and a target acceptance probability of 0.95, using sample_numpyro_nuts() with otherwise default settings.
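For anyone trying to reproduce this, here is a minimal sketch of the kind of setup described. It assumes the ragged data arrive as nested Python lists, ordered second-level group → third-level group → (x, y) observations. The names `flatten_ragged` and `build_model`, the priors, and the single global slope `beta` are all hypothetical. The idea is the usual one for ragged multilevel data in PyMC: flatten everything to long format and index group parameters with integer arrays instead of padding.

```python
import numpy as np


def flatten_ragged(data):
    """Flatten nested ragged data into long format with integer indices.

    `data[s][t]` is the list of (x, y) observations for third-level group t
    inside second-level group s; each level may have a different length.
    Returns x, y, third_idx (observation -> third-level group),
    second_of_third (third-level group -> second-level group), n_second.
    """
    xs, ys, third_idx, second_of_third = [], [], [], []
    t = 0  # running index over all third-level groups
    for s, third_groups in enumerate(data):
        for obs in third_groups:
            second_of_third.append(s)
            for x, y in obs:
                xs.append(x)
                ys.append(y)
                third_idx.append(t)
            t += 1
    return (
        np.asarray(xs, dtype=float),
        np.asarray(ys, dtype=float),
        np.asarray(third_idx, dtype=np.int64),
        np.asarray(second_of_third, dtype=np.int64),
        len(data),
    )


def build_model(x, y, third_idx, second_of_third, n_second):
    """Hypothetical Global -> second level -> third level -> output model."""
    import pymc as pm  # imported here so flatten_ragged works without PyMC

    with pm.Model() as model:
        mu_global = pm.Normal("mu_global", mu=0.0, sigma=1.0)
        sigma_second = pm.HalfNormal("sigma_second", sigma=1.0)
        mu_second = pm.Normal(
            "mu_second", mu=mu_global, sigma=sigma_second, shape=n_second
        )
        sigma_third = pm.HalfNormal("sigma_third", sigma=1.0)
        mu_third = pm.Normal(
            "mu_third",
            mu=mu_second[second_of_third],  # each group is centred on its parent
            sigma=sigma_third,
            shape=len(second_of_third),
        )
        beta = pm.Normal("beta", mu=0.0, sigma=1.0)  # slope of the single predictor
        sigma_y = pm.HalfNormal("sigma_y", sigma=1.0)
        pm.Normal(
            "y_obs", mu=mu_third[third_idx] + beta * x, sigma=sigma_y, observed=y
        )
    return model


# Sampling (not run here; needs `import pymc as pm`):
# with build_model(*flatten_ragged(data)):
#     idata = pm.sampling_jax.sample_numpyro_nuts(
#         draws=1000, tune=1000, chains=4, target_accept=0.95
#     )
```

The commented sampling call mirrors the settings mentioned in the post. Reading them as 1000 tuning steps plus 1000 draws is an assumption. Newer PyMC releases also expose the same function under `pymc.sampling.jax`.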