GPU is running the model much slower than the CPU

Hello,

I have a GCP instance with 4 Tesla T4 GPUs. Sampling my data takes 4 hours on the GPUs, while running on the CPU takes only 1 hour.

I found this thread: "Gpu much slower than cpu" (Questions, PyMC Discourse).

But I have also read that fitting can be faster with a GPU and JAX.

Here is my setup.

<module 'jax.version' from '/opt/conda/lib/python3.7/site-packages/jax/version.py'>
<module 'jaxlib.version' from '/opt/conda/lib/python3.7/site-packages/jaxlib/version.py'>
PyMC Version: 4.0.1
Aesara Version: 2.7.3
ArviZ Version: 0.12.1
Using gpu with [GpuDevice(id=0, process_index=0), GpuDevice(id=1, process_index=0), GpuDevice(id=2, process_index=0), GpuDevice(id=3, process_index=0)]
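Side note on the output above: the first two lines print the jax.version and jaxlib.version modules themselves rather than their version strings. A minimal sketch for reporting the actual strings, assuming each package exposes __version__ either at the top level or in a version submodule (the output above shows that jax and jaxlib have such a submodule). The helper name is hypothetical, and it uses only the standard library, so it also runs on the Python 3.7 shown in the paths:

```python
import importlib


def installed_version(name):
    """Return the version string of an importable package, or None if it is missing."""
    try:
        module = importlib.import_module(name)
    except ImportError:
        return None
    version = getattr(module, "__version__", None)
    if version is None:
        # Fall back to a `<package>.version` submodule, like jax.version / jaxlib.version
        try:
            version = importlib.import_module(name + ".version").__version__
        except (ImportError, AttributeError):
            pass
    return version


if __name__ == "__main__":
    for pkg in ("pymc", "aesara", "arviz", "jax", "jaxlib"):
        print(f"{pkg}: {installed_version(pkg)}")
```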

Here is my model for reference.

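import numpy as np
import pymc as pm
import pymc.sampling_jax  # needed for pymc.sampling_jax.sample_numpyro_nuts below

coords={"locations":locations,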
            "items":items,
            'months':months,
            'changepoints':df_train.index.get_level_values(0)[np.argwhere(np.diff(A, axis=0) != 0)[:, 0]],
            "yearly_components": [f'yearly_{f}_{i+1}' for f in ['cos', 'sin'] for i in range(yearly_fourier.shape[1] // 2)],
            "obs_id":[f'{loc}_{time.year}_month_{time.month}_item_{item}' for time, loc, item in df_train.index.values]}

with pm.Model(coords=coords) as model:
        
    A = pm.Data('A', A, mutable=True, dims=['time', 'changepoints'])
    s = pm.Data('s', s, mutable=True, dims=['changepoints'])
    yearly = pm.Data('yearly_season', yearly_fourier, mutable=True, dims=['obs_id', 'yearly_components'])

    # Slope
    mu_slope = pm.Normal('mu_slope', mu=0, sigma=0.1)
    sigma_loc_slope = pm.HalfNormal('sigma_loc_slope', sigma=0.1)
    sigma_item_slope = pm.HalfNormal('sigma_item_slope', sigma=0.1)
    offset_loc_slope = pm.Normal('offset_loc_slope', mu=0, sigma=0.1, dims=['locations'])
    offset_item_slope = pm.Normal('offset_item_slope', mu=0, sigma=0.1, dims=['items'])

    loc_slope = sigma_loc_slope * offset_loc_slope
    item_slope = sigma_item_slope * offset_item_slope
    initial_slope = pm.Deterministic('initial_slope', mu_slope + loc_slope[location_idxs] + item_slope[item_idxs],
                                     dims=['obs_id'])

    # Intercept
    mu_intercept = pm.Normal('mu_intercept', mu=0, sigma=0.1)
    sigma_loc_intercept = pm.HalfNormal('sigma_loc_intercept', sigma=0.1)
    sigma_item_intercept = pm.HalfNormal('sigma_item_intercept', sigma=0.1)
    offset_loc_intercept = pm.Normal('offset_loc_intercept', mu=0, sigma=0.1, dims=['locations'])
    offset_item_intercept = pm.Normal('offset_item_intercept', mu=0, sigma=0.1, dims=['items'])

    loc_intercept = sigma_loc_intercept * offset_loc_intercept
    item_intercept = sigma_item_intercept * offset_item_intercept
    initial_intercept = pm.Deterministic('initial_intercept', mu_intercept + loc_intercept[location_idxs] + item_intercept[item_idxs],
                                         dims=['obs_id'])
    # Offsets
    mu_delta = pm.Normal('mu_delta', 0, 0.1)
    sigma_loc_delta = pm.HalfNormal('sigma_loc_delta', sigma=0.1)
    sigma_item_delta = pm.HalfNormal('sigma_item_delta', sigma=0.1)
    offset_loc_delta = pm.Normal('offset_loc_delta', mu=0, sigma=0.25, dims=['locations', 'changepoints'])
    offset_item_delta = pm.Normal('offset_item_delta', mu=0, sigma=0.25, dims=['items', 'changepoints'])

    loc_delta = sigma_loc_delta * offset_loc_delta
    item_delta = sigma_item_delta * offset_item_delta
    delta = pm.Deterministic('delta', mu_delta + loc_delta[location_idxs, :] + item_delta[item_idxs, :], dims=['obs_id', 'changepoints'])
    
    # Yearly seasonality (Fourier terms)
    yearly_mu = pm.Normal('yearly_mu', 0, 0.1)
    yearly_sigma = pm.HalfNormal('yearly_sigma', sigma=0.1)
    yearly_beta = pm.Normal('yearly_beta', yearly_mu, yearly_sigma, dims=['locations', 'yearly_components'])
    yearly_seasonality = pm.Deterministic('yearly_seasonality', (yearly[time_idxs] * yearly_beta[location_idxs, :]).sum(axis=1), dims=['obs_id'])

    # Monthly Effects
    beta_month = pm.Normal('beta_month', mu=0, sigma=0.1, dims=['months'])

    intercept = initial_intercept + ((-s * A)[time_idxs, :] * delta).sum(axis=1)
    slope = initial_slope + (A[time_idxs, :] * delta).sum(axis=1)

    mu = pm.Deterministic('mu', intercept + slope * t + yearly_seasonality + beta_month[month_idxs], dims=['obs_id'])
    likelihood = pm.Poisson('predicted_eaches',
                            mu=np.exp(mu),
                            observed=y,
                            dims=['obs_id'],)
    
    prior = pm.sample_prior_predictive()
    trace = pymc.sampling_jax.sample_numpyro_nuts(tune=2000, draws = 1000)

Am I doing something wrong that slows the GPUs down more than the CPU?


I’m no expert, but assuming you followed the steps in "Set up JAX sampling with GPUs in PyMC v4" (pymc-devs/pymc Wiki on GitHub), are you on CUDA 11.4 and cuDNN v8.2.4? What does nvcc --version show?


nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2021 NVIDIA Corporation
Built on Mon_Oct_11_21:27:02_PDT_2021
Cuda compilation tools, release 11.4, V11.4.152
Build cuda_11.4.r11.4/compiler.30521435_0
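For reference, a small sketch that pulls the toolkit release out of nvcc --version so it can be compared with the CUDA 11.4 mentioned above. The helper names are hypothetical. Keep in mind that nvcc reports whichever toolkit is first on PATH, which is not necessarily the CUDA build that jaxlib was installed against.

```python
import re
import shutil
import subprocess


def cuda_release(nvcc_output):
    """Parse the CUDA toolkit release, e.g. (11, 4), from `nvcc --version` output."""
    match = re.search(r"release (\d+)\.(\d+)", nvcc_output)
    if match is None:
        return None
    return int(match.group(1)), int(match.group(2))


def local_nvcc_output():
    """Return the stdout of `nvcc --version`, or None if nvcc is not on PATH."""
    if shutil.which("nvcc") is None:
        return None
    result = subprocess.run(["nvcc", "--version"], capture_output=True, text=True)
    return result.stdout


if __name__ == "__main__":
    output = local_nvcc_output()
    print(cuda_release(output) if output else "nvcc not found on PATH")
```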

I’m really not sure this would fix the main issue, but would replacing np.exp(mu) with the Aesara equivalent (pymc.math.exp()) help?

Hi Jordan. Try adding chain_method="vectorized" as an argument to sample_numpyro_nuts. This should help.

Thank you. I added the argument and the time per iteration went from 7.56 s/it to 1.56 s/it. A good improvement.

UPDATE

At the end of sampling, I got an out-of-memory error that I don’t get on the CPUs.

2022-07-27 12:46:57.189710: W external/org_tensorflow/tensorflow/core/common_runtime/bfc_allocator.cc:479] Allocator (GPU_0_bfc) ran out of memory trying to allocate 8.85GiB (rounded to 9501120000)requested by op 
2022-07-27 12:46:57.189959: W external/org_tensorflow/tensorflow/core/common_runtime/bfc_allocator.cc:491] ****************____________________________________________________________________________________
2022-07-27 12:46:57.190041: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2129] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 9501120000 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
             parameter allocation:    4.69MiB
              constant allocation:         0B
        maybe_live_out allocation:    8.85GiB
     preallocated temp allocation:         0B
                 total allocation:    8.85GiB
              total fragmentation:         0B (0.00%)
Peak buffers:
	Buffer 1:
		Size: 8.85GiB
		Operator: op_name="jit(gather)/jit(main)/gather[dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 1, 3), collapsed_slice_dims=(2,), start_index_map=(2,)) slice_sizes=(4, 1000, 1, 10) unique_indices=False indices_are_sorted=False mode=GatherScatterMode.PROMISE_IN_BOUNDS fill_value=None]" source_file="/opt/conda/lib/python3.7/site-packages/aesara/link/jax/dispatch.py" source_line=606
		XLA Label: gather
		Shape: f64[4,1000,29691,10]
		==========================

	Buffer 2:
		Size: 4.58MiB
		Entry Parameter Subshape: f64[4,1000,15,10]
		==========================

	Buffer 3:
		Size: 116.0KiB
		Entry Parameter Subshape: s32[29691,1]
		==========================


---------------------------------------------------------------------------
XlaRuntimeError                           Traceback (most recent call last)
/tmp/ipykernel_3823/574895807.py in <module>
     82 
     83     # prior = pm.sample_prior_predictive()
---> 84     trace = pymc.sampling_jax.sample_numpyro_nuts(tune=2000, draws = 1000, chain_method="vectorized")
     85 
     86 print(datetime.now())

/opt/conda/lib/python3.7/site-packages/pymc/sampling_jax.py in sample_numpyro_nuts(draws, tune, chains, target_accept, random_seed, initvals, model, var_names, progress_bar, keep_untransformed, chain_method, postprocessing_backend, idata_kwargs, nuts_kwargs)
    534     jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=vars_to_sample)
    535     result = jax.vmap(jax.vmap(jax_fn))(
--> 536         *jax.device_put(raw_mcmc_samples, jax.devices(postprocessing_backend)[0])
    537     )
    538     mcmc_samples = {v.name: r for v, r in zip(vars_to_sample, result)}

    [... skipping hidden 6 frame]

/opt/conda/lib/python3.7/site-packages/aesara/link/utils.py in jax_funcified_fgraph(mu_slope, sigma_loc_slope_log_, sigma_item_slope_log_, offset_loc_slope, offset_item_slope, mu_intercept, sigma_loc_intercept_log_, sigma_item_intercept_log_, offset_loc_intercept, offset_item_intercept, mu_delta, sigma_loc_delta_log_, sigma_item_delta_log_, offset_loc_delta, offset_item_delta, yearly_mu, yearly_sigma_log_, yearly_beta, beta_month)
      4     auto_62624 = subtensor(beta_month, auto_2773)
      5     # AdvancedSubtensor1(yearly_beta, TensorConstant{[ 0  0  0 .. 14 14 14]})
----> 6     auto_62625 = subtensor1(yearly_beta, auto_1666)
      7     # Elemwise{exp,no_inplace}(sigma_item_slope_log__)
      8     sigma_item_slope = exp(sigma_item_slope_log_)

/opt/conda/lib/python3.7/site-packages/aesara/link/jax/dispatch.py in subtensor(x, *ilists)
    604             indices = indices[0]
    605 
--> 606         return x.__getitem__(indices)
    607 
    608     return subtensor

    [... skipping hidden 1 frame]

/opt/conda/lib/python3.7/site-packages/jax/_src/numpy/lax_numpy.py in _rewriting_take(arr, idx, indices_are_sorted, unique_indices, mode, fill_value)
   3569   treedef, static_idx, dynamic_idx = _split_index_for_jit(idx, arr.shape)
   3570   return _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted,
-> 3571                  unique_indices, mode, fill_value)
   3572 
   3573 # TODO(phawkins): re-enable jit after fixing excessive recompilation for

/opt/conda/lib/python3.7/site-packages/jax/_src/numpy/lax_numpy.py in _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted, unique_indices, mode, fill_value)
   3599       unique_indices=unique_indices or indexer.unique_indices,
   3600       indices_are_sorted=indices_are_sorted or indexer.indices_are_sorted,
-> 3601       mode=mode, fill_value=fill_value)
   3602 
   3603   # Reverses axes with negative strides.

    [... skipping hidden 16 frame]

/opt/conda/lib/python3.7/site-packages/jax/_src/dispatch.py in _execute_compiled(name, compiled, input_handler, output_buffer_counts, result_handler, has_unordered_effects, ordered_effects, kept_var_idx, *args)
    715     in_flat, token_handler = _add_tokens(has_unordered_effects, ordered_effects,
    716                                          device, in_flat)
--> 717   out_flat = compiled.execute(in_flat)
    718   check_special(name, out_flat)
    719   out_bufs = unflatten(out_flat, output_buffer_counts)

XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 9501120000 bytes.
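Reading the log: the 8.85 GiB buffer is exactly what a float64 array of shape [4, 1000, 29691, 10] needs. Those axes appear to be (chains, draws, observations, yearly components). The traceback points at AdvancedSubtensor1(yearly_beta, ...), i.e. yearly_beta[location_idxs, :], evaluated for all 4 × 1000 draws at once inside the jax.vmap(jax.vmap(jax_fn)) postprocessing call in sampling_jax.py. A quick stdlib check of the arithmetic (8 bytes per float64 and 4 per int32 are standard; the axis interpretation is an inference from the log):

```python
# Bytes per element for the XLA dtypes that appear in the log
DTYPE_BYTES = {"f64": 8, "f32": 4, "s32": 4}


def buffer_bytes(shape, dtype="f64"):
    """Size in bytes of a dense array with the given shape and XLA dtype."""
    n_elements = 1
    for dim in shape:
        n_elements *= dim
    return n_elements * DTYPE_BYTES[dtype]


def to_gib(n_bytes):
    return n_bytes / 2**30


# Buffer 1: (chains, draws, observations, yearly components), as inferred from the log
gathered = buffer_bytes((4, 1000, 29691, 10))
print(gathered, f"{to_gib(gathered):.2f} GiB")  # 9501120000 8.85 GiB

# Buffer 2: presumably the yearly_beta draws, (chains, draws, locations, components)
print(buffer_bytes((4, 1000, 15, 10)) / 2**20)  # about 4.58 (MiB)

# Buffer 3: the int32 index array with one entry per observation
print(buffer_bytes((29691, 1), "s32") / 2**10)  # about 116.0 (KiB)
```

Because the size is the product of chains, draws and observations, halving the data roughly halves this buffer, and so does halving the number of draws.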

I cut the number of observations by sampling half of the data. I think in this case I need to stick with the CPU, since I can’t get the model to fit within these GPUs’ memory limits.
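An alternative to cutting the data, sketched here but untested: the traceback shows that the OOM happens after sampling, when sample_numpyro_nuts moves the raw draws to jax.devices(postprocessing_backend)[0] and evaluates the deterministics there. The signature in the traceback lists a postprocessing_backend argument, so passing "cpu" should move that step into host RAM while NUTS itself still runs on the GPU. The helper name is hypothetical, and the host still needs enough free memory for the roughly 8.85 GiB array plus the other per-observation deterministics such as mu and delta.

```python
def sample_gpu_postprocess_cpu(model, tune=2000, draws=1000, **kwargs):
    """Hypothetical helper: run NumPyro NUTS on the default (GPU) backend, but
    compute transforms and deterministics on the CPU. Untested on PyMC 4.0.1."""
    # Imported here so the helper can be defined without PyMC installed
    import pymc.sampling_jax

    return pymc.sampling_jax.sample_numpyro_nuts(
        tune=tune,
        draws=draws,
        model=model,
        chain_method="vectorized",  # the setting that sped up sampling earlier in the thread
        postprocessing_backend="cpu",  # parameter listed in the traceback's signature
        **kwargs
    )
```

Usage would be trace = sample_gpu_postprocess_cpu(model) after the model block above.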

Thank you. This, combined with the chain_method change, seemed to help.

In case JAX is preallocating 75% of the GPU memory by default and that is causing an OOM error where your data would not otherwise use up that space, this might help:
export XLA_PYTHON_CLIENT_PREALLOCATE=false

Other configuration options are described in "GPU memory allocation" in the JAX documentation.
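The traceback path /tmp/ipykernel_3823/574895807.py suggests this runs in a Jupyter kernel, where an export in a separate shell will not reach the already running kernel. A minimal sketch that sets the same variable from Python instead. It has to take effect before JAX initialises its GPU backend, so in practice put it in the first cell and restart the kernel if JAX has already run. The function name is hypothetical; the variable name is the one from the export line above.

```python
import os
import sys
import warnings


def disable_jax_gpu_preallocation():
    """Ask XLA to allocate GPU memory on demand instead of preallocating it.

    Call before importing jax or pymc.sampling_jax, or at least before the
    first JAX computation runs.
    """
    if "jax" in sys.modules:
        warnings.warn(
            "jax is already imported; if a JAX computation has already run, "
            "restart the kernel and call this first."
        )
    os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"


disable_jax_gpu_preallocation()
```

The same JAX page also documents XLA_PYTHON_CLIENT_MEM_FRACTION, which keeps preallocation but changes the 75% default.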