Slow NumPyro sampling on a Colab GPU

When I use `pm.sampling_jax.sample_numpyro_nuts` on Colab with a GPU, my model slows to under 2 it/s. The same model on the CPU runs at about 10 it/s.

Has anyone else noticed this?

Does your model contain a loop (either a Python loop or `scan`)? Loops are usually slow in JAX, and even slower on a GPU.

I don’t think so. Here is the model:

def build_model(velocity_scaled,
                load_scaled,
                session_exercise_id,
                session_id,
                coords,
                render_model=True):
    """Build a hierarchical quadratic load-velocity model.

    For each observation the expected load is
    ``intercept - slope * velocity - curve * velocity**2``. The intercept,
    slope, curve and noise scale are all drawn per session. The session-level
    values are drawn around per-exercise values, which in turn share global
    hyperpriors. The intercept uses a non-centered parameterization
    (``location + scale * standard_normal_offset``) at both the exercise and
    the session level.

    Parameters
    ----------
    velocity_scaled : array-like
        Predictor, one value per entry of ``coords['observation']``.
    load_scaled : array-like
        Observed response, one value per entry of ``coords['observation']``.
    session_exercise_id : array-like of int
        For every session, the position of its exercise along the
        ``'exercise'`` dim.
    session_id : array-like of int
        For every observation, the position of its session along the
        ``'session'`` dim.
    coords : mapping
        Must provide labels for ``'observation'``, ``'exercise'`` and
        ``'session'``.
    render_model : bool, default True
        When true, show the model graph via ``pm.model_to_graphviz``. This
        expects ``display`` to be in scope (e.g. an IPython/Jupyter session).

    Returns
    -------
    pm.Model
        The constructed model.
    """
    with pm.Model() as model:

        # Dims are registered as mutable, as are the data containers below,
        # so their values can be replaced later (e.g. via pm.set_data).
        for dim in ('observation', 'exercise', 'session'):
            model.add_coord(dim, coords[dim], mutable=True)

        # Data containers. Note the registered names keep the argument names.
        velocity = pm.MutableData('velocity_scaled', velocity_scaled,
                                  dims='observation')
        exercise_of_session = pm.MutableData('session_exercise_id',
                                             session_exercise_id,
                                             dims='session')
        session_of_obs = pm.MutableData('session_id', session_id,
                                        dims='observation')

        # --- Global hyperpriors -------------------------------------------
        intercept_loc = pm.Normal('intercept_global', mu=0.0, sigma=1.0)
        intercept_scale = pm.HalfNormal('intercept_sigma_global', sigma=1.0)
        slope_scale = pm.HalfNormal('slope_global', sigma=1.0)
        curve_scale = pm.HalfNormal('curve_global', sigma=3.0)
        noise_scale = pm.HalfNormal('error_global', sigma=1.0)

        # --- Exercise level -----------------------------------------------
        # Non-centered intercept: global location + global scale * N(0, 1).
        intercept_z_ex = pm.Normal('intercept_offset_exercise',
                                   mu=0.0, sigma=1.0, dims='exercise')
        intercept_ex = pm.Deterministic(
            'intercept_exercise',
            intercept_loc + intercept_scale*intercept_z_ex,
            dims='exercise',
        )
        intercept_sd_ex = pm.HalfNormal('intercept_sigma_exercise',
                                        sigma=1.0, dims='exercise')
        slope_ex = pm.HalfNormal('slope_exercise',
                                 sigma=slope_scale, dims='exercise')
        curve_ex = pm.HalfNormal('curve_exercise',
                                 sigma=curve_scale, dims='exercise')
        noise_ex = pm.HalfNormal('error_exercise',
                                 sigma=noise_scale, dims='exercise')

        # --- Session level ------------------------------------------------
        # Each session looks up its exercise's parameters by index.
        intercept_z_ses = pm.Normal('intercept_offset_session',
                                    mu=0.0, sigma=1.0, dims='session')
        intercept_ses = pm.Deterministic(
            'intercept_session',
            (intercept_ex[exercise_of_session]
             + intercept_sd_ex[exercise_of_session]
             * intercept_z_ses),
            dims='session',
        )
        slope_ses = pm.HalfNormal('slope_session',
                                  sigma=slope_ex[exercise_of_session],
                                  dims='session')
        curve_ses = pm.HalfNormal('curve_session',
                                  sigma=curve_ex[exercise_of_session],
                                  dims='session')
        noise_ses = pm.HalfNormal('error_session',
                                  sigma=noise_ex[exercise_of_session],
                                  dims='session')

        # --- Observation level --------------------------------------------
        # Broadcast each session's parameters onto its observations.
        obs_intercept = intercept_ses[session_of_obs]
        obs_slope = slope_ses[session_of_obs]
        obs_curve = curve_ses[session_of_obs]
        obs_noise = noise_ses[session_of_obs]

        expected_load = pm.Deterministic(
            'load_mu',
            obs_intercept - obs_slope*velocity - obs_curve*velocity**2,
            dims='observation',
        )

        # Likelihood (registered on the model; no local handle is needed).
        pm.Normal('load_estimate',
                  mu=expected_load,
                  sigma=obs_noise,
                  observed=load_scaled,
                  dims='observation')

    if render_model:
        display(pm.model_to_graphviz(model))

    return model