Slow inference for numpyro sampling on Colab GPU

When I use `pm.sampling_jax.sample_numpyro_nuts` on Colab with a GPU, my model drops to under 2 it/s, whereas the same call on the CPU runs at ~10 it/s.

Has anyone else noticed this?

Does your model have a loop (either a Python loop or `scan`)? Loops are usually slow in JAX, and slower still on a GPU.

I don’t think so:

def build_model(
    velocity_scaled,
    load_scaled,
    session_exercise_id,
    session_id,
    coords,
    render_model=True,
):
    """Build a three-level (global -> exercise -> session) load-velocity model.

    The expected load of each observation is a quadratic in velocity,

        load_mu = intercept - slope * v - curve * v**2,

    where ``intercept``, ``slope``, ``curve`` and the noise scale ``error``
    are per-session parameters.  Session parameters are drawn around their
    exercise's parameters, which in turn are drawn around global ones.

    Parameters
    ----------
    velocity_scaled : array-like
        Predictor, one value per observation.  Registered as mutable data
        ``'velocity_scaled'`` on the ``'observation'`` dim.
    load_scaled : array-like
        Observed response.  It is passed as ``observed=`` to the likelihood,
        one value per observation.  Unlike the other inputs, it is *not*
        wrapped in ``pm.MutableData``.
    session_exercise_id : array-like of int
        For each session, the index of its exercise (``'session'`` dim).
    session_id : array-like of int
        For each observation, the index of its session (``'observation'`` dim).
    coords : mapping
        Must provide values for the ``'observation'``, ``'exercise'`` and
        ``'session'`` keys.  Each is added to the model as a mutable
        coordinate.
    render_model : bool, default True
        If true, show the model graph via
        ``display(pm.model_to_graphviz(model))``.  ``display`` is not imported
        in this function, so it must already be in scope (e.g. an
        IPython/Jupyter session).

    Returns
    -------
    pm.Model
        The constructed (unsampled) model.
    """
    with pm.Model() as model:

        # Register the three dims as mutable coordinates.
        for dim in ('observation', 'exercise', 'session'):
            model.add_coord(dim, coords[dim], mutable=True)

        # Data containers.  The two integer index arrays drive the
        # exercise->session and session->observation gathers below.
        velocity = pm.MutableData('velocity_scaled', velocity_scaled, dims='observation')
        exercise_of_session = pm.MutableData('session_exercise_id', session_exercise_id, dims='session')
        session_of_obs = pm.MutableData('session_id', session_id, dims='observation')

        # --- Global level ----------------------------------------------------
        intercept_mu_g = pm.Normal('intercept_global', mu=0.0, sigma=1.0)
        intercept_sd_g = pm.HalfNormal('intercept_sigma_global', sigma=1.0)
        slope_g = pm.HalfNormal('slope_global', sigma=1.0)
        curve_g = pm.HalfNormal('curve_global', sigma=3.0)
        error_g = pm.HalfNormal('error_global', sigma=1.0)

        # --- Exercise level --------------------------------------------------
        # Non-centred intercept: global mean + global scale * standard-normal offset.
        intercept_z_ex = pm.Normal('intercept_offset_exercise', mu=0.0, sigma=1.0, dims='exercise')
        intercept_ex = pm.Deterministic(
            'intercept_exercise',
            intercept_mu_g + intercept_sd_g * intercept_z_ex,
            dims='exercise',
        )
        intercept_sd_ex = pm.HalfNormal('intercept_sigma_exercise', sigma=1.0, dims='exercise')
        # Slope, curvature and noise are HalfNormal, scaled by their global counterparts.
        slope_ex = pm.HalfNormal('slope_exercise', sigma=slope_g, dims='exercise')
        curve_ex = pm.HalfNormal('curve_exercise', sigma=curve_g, dims='exercise')
        error_ex = pm.HalfNormal('error_exercise', sigma=error_g, dims='exercise')

        # --- Session level ---------------------------------------------------
        intercept_z_sess = pm.Normal('intercept_offset_session', mu=0.0, sigma=1.0, dims='session')
        intercept_sess = pm.Deterministic(
            'intercept_session',
            intercept_ex[exercise_of_session]
            + intercept_sd_ex[exercise_of_session] * intercept_z_sess,
            dims='session',
        )
        # NOTE(review): these session-level HalfNormals take their scale directly
        # from the exercise-level draws (a centred parameterisation).  If NUTS
        # reports divergences or very long trees, a non-centred form
        # (scale * HalfNormal(1)) may sample faster -- worth profiling.
        slope_sess = pm.HalfNormal('slope_session', sigma=slope_ex[exercise_of_session], dims='session')
        curve_sess = pm.HalfNormal('curve_session', sigma=curve_ex[exercise_of_session], dims='session')
        error_sess = pm.HalfNormal('error_session', sigma=error_ex[exercise_of_session], dims='session')

        # --- Observation level -----------------------------------------------
        # Gather each observation's session-level parameters.
        intercept, slope, curve, error = (
            param[session_of_obs]
            for param in (intercept_sess, slope_sess, curve_sess, error_sess)
        )

        velocity_sq = velocity ** 2
        load_mu = pm.Deterministic(
            'load_mu',
            intercept - slope * velocity - curve * velocity_sq,
            dims='observation',
        )

        # Gaussian likelihood with per-session noise scale.
        pm.Normal(
            'load_estimate',
            mu=load_mu,
            sigma=error,
            observed=load_scaled,
            dims='observation',
        )

    if render_model:
        display(pm.model_to_graphviz(model))

    return model

I set up my own local environment and I face the same issue.
Is there something inherently wrong with my model?

I’ve also tried making sure that all the data are either float32 or int32.