GPU utilization is high but memory usage is very low, leading to subpar sampling performance

I’m trying to run a simple pymc3 example on a desktop GPU, but the performance is awful compared to my laptop CPU. Running watch nvidia-smi shows that GPU utilization is very high (~94%), yet memory usage is only about 10% of available memory.

For context, using the same model, initializing NUTS on my laptop achieves ~4500 it/s, while on my GPU-enabled desktop it caps out at ~65 it/s. Any ideas what could cause this?
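In case it helps, this is roughly how I check which device and float precision theano ends up using on each machine (just a sanity check, not part of the model code below):

# Sanity check: confirm which device theano compiles to and its default float type.
import theano

print(theano.config.device)  # e.g. 'cuda0' on the desktop, 'cpu' on the laptop
print(theano.config.floatX)  # GPUs generally want 'float32'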

To add more context, here is an example:

import numpy as np
import pymc3 as pm

n_unique_ids = 2500
N = 5000

# N observation ids (2500 random ids, each repeated twice) plus a binary churn outcome per observation
ids = np.repeat(np.random.randint(0, n_unique_ids, n_unique_ids), 2)
churn = np.random.randint(0, 2, N)  # 0/1 class

with pm.Model() as bb_model:
    p = pm.Beta('p', 1., 1., shape=n_unique_ids)
    
    y = pm.Bernoulli('y_obs', p[ids], observed=churn)  
    bb_trace = pm.sample(3000, tune=1000)

##  NUTS initialization ~65 it/s

I was able to improve the performance by converting ids and churn to shared variables, but it still doesn't match my laptop CPU performance:

import numpy as np
import pymc3 as pm
from theano import shared

n_unique_ids = 2500
N = 5000

ids = np.repeat(np.random.randint(0, n_unique_ids, n_unique_ids), 2)
churn = np.random.randint(0, 2, N)  # 0/1 class

shared_ids = shared(ids)
shared_churn = shared(churn)

with pm.Model() as bb_model:
    p = pm.Beta('p', 1., 1., shape=n_unique_ids)
    
    y = pm.Bernoulli('y_obs', p[shared_ids], observed=shared_churn)  
    bb_trace = pm.sample(3000, tune=1000)

## NUTS initialization ~1550 it/s
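As a side note, once the data lives in shared variables it can be swapped out for a new dataset without recompiling the model. A small sketch using theano's set_value, with hypothetical replacement arrays new_ids / new_churn:

# Illustrative only: replace the observed data in place; shared_ids and
# shared_churn are theano shared variables, so the compiled model is reused.
new_ids = np.repeat(np.random.randint(0, n_unique_ids, n_unique_ids), 2)
new_churn = np.random.randint(0, 2, N)

shared_ids.set_value(new_ids)
shared_churn.set_value(new_churn)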

With shared variables, GPU utilization is also much lower (~41%), but memory usage is still only ~10% of available memory.

This isn’t surprising; this just isn’t a problem that would run faster on the GPU. NUTS itself runs on the CPU, so the data has to be copied back and forth between GPU memory and RAM on every evaluation. Using the GPU is only worth it if there is a lot of work per gradient evaluation. For example, a large Cholesky decomposition might be a good candidate.
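For instance, something along these lines (a rough sketch with made-up dimensions and priors, just to illustrate the kind of workload) spends most of each gradient evaluation inside dense linear algebra on a large matrix, which is where offloading to the GPU can actually pay off:

import numpy as np
import pymc3 as pm

dim = 500      # illustrative size; large enough that the Cholesky work dominates
n_obs = 1000
data = np.random.randn(n_obs, dim)  # placeholder data

with pm.Model() as mvn_model:
    # Full covariance matrix: every gradient evaluation factors a dim x dim
    # matrix, so there is substantial work that can be handed off to the GPU.
    packed_chol = pm.LKJCholeskyCov('packed_chol', n=dim, eta=2.,
                                    sd_dist=pm.HalfCauchy.dist(2.5))
    chol = pm.expand_packed_triangular(dim, packed_chol, lower=True)
    mu = pm.Normal('mu', 0., 10., shape=dim)
    obs = pm.MvNormal('obs', mu=mu, chol=chol, observed=data)
    trace = pm.sample(1000, tune=1000)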

Would the Beta-Bernoulli model above be a good candidate for a GPU speed-up?

In practice, these models will be run on huge datasets.