Scaling NUTS/PGBART sampler across 16 virtual CPUs on SageMaker

I’m trying to sample from a Bernoulli distribution using NUTS with a PGBART step; I’m using a BART model for a classification problem. Sampling 1 chain of 1,000 draws takes ~40 minutes on my laptop, so I spun up an AWS SageMaker Studio instance thinking it would be faster. Surprisingly, sampling was no faster on the instance, regardless of how large an instance I tried. I tested 16 vCPUs + 64 GB RAM, 32 vCPUs, 64 vCPUs, etc. When running on 16 vCPUs, I noticed the instance CPU usage stayed right around 6%, which is about what one fully busy core out of 16 would show, so it looks like the sampler is only using 1 CPU.
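As a sanity check on that 6% figure, here is a minimal sketch using only the standard library. The helper name `single_core_share` is just something I made up for illustration; it is not part of any sampler API. It prints how many logical CPUs the Python process can see and what share of total capacity a single saturated core would account for on 16 vCPUs.

```python
import os


def single_core_share(n_cpus: int) -> float:
    """Percent of total CPU capacity used when exactly one core is fully busy."""
    return 100.0 / n_cpus


# Number of logical CPUs visible to this Python process (may be None on some platforms).
visible_cpus = os.cpu_count()
print(f"visible CPUs: {visible_cpus}")

# On a 16-vCPU instance, one busy core is 1/16 of the total.
print(f"one busy core on 16 vCPUs: {single_core_share(16):.2f}% of total")
```

If the first line reports 16 and overall utilization still sits near the second number, that would be consistent with the sampling running in a single process on a single core.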

Does anyone know how I can get the sampler to use all of the CPUs on the instance? The instance type is ml.m5.xlarge, running Python 3.7.

For the NUTS/PGBART steps, I’m using these settings:

n_samples = 1000
n_tune_samples = 800
target_accept_param = 0.8
n_particles = 10
max_stages = 150

Thanks!