When using pm.sampling_jax.sample_numpyro_nuts
on Colab with a GPU, my model drops to <2 it/s, while the same model on the CPU runs at ~10 it/s.
Has anyone else noticed this?
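For context, this is roughly the kind of call I mean (toy model and draw counts only, not my actual setup):

import pymc as pm
import pymc.sampling_jax  # explicit import needed on some PyMC versions

with pm.Model() as toy_model:
    mu = pm.Normal('mu', mu = 0.0, sigma = 1.0)
    pm.Normal('y', mu = mu, sigma = 1.0, observed = [0.1, -0.2, 0.3])
    # Identical call on both runtimes; only the Colab hardware accelerator differs
    idata = pm.sampling_jax.sample_numpyro_nuts(draws = 1000, tune = 1000, chains = 4)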
Does your model have a loop (either a Python loop or a scan)? Loops are usually slow in JAX and even slower on a GPU.
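For illustration (made-up names and shapes, sketched in plain numpy just to show the general pattern I mean): a per-observation loop produces many small operations, while one vectorized indexing step keeps everything as a few large array ops.

import numpy as np

rng = np.random.default_rng(0)
n_obs, n_groups = 1000, 10
x = rng.normal(size = n_obs)
group_idx = rng.integers(0, n_groups, size = n_obs)
intercept = rng.normal(size = n_groups)
slope = rng.normal(size = n_groups)

# Loop version: thousands of small ops, the kind of graph that compiles slowly
# under JAX and that a GPU executes inefficiently
mu_loop = np.empty(n_obs)
for i in range(n_obs):
    mu_loop[i] = intercept[group_idx[i]] + slope[group_idx[i]]*x[i]

# Vectorized version: a handful of large array operations
mu_vec = intercept[group_idx] + slope[group_idx]*x

assert np.allclose(mu_loop, mu_vec)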
I don’t think so:
import pymc as pm
from IPython.display import display

def build_model(velocity_scaled,
                load_scaled,
                session_exercise_id,
                session_id,
                coords,
                render_model = True):
    with pm.Model() as model:
        # Add coordinates
        model.add_coord('observation', coords['observation'], mutable = True)
        model.add_coord('exercise', coords['exercise'], mutable = True)
        model.add_coord('session', coords['session'], mutable = True)
        # Add inputs
        velocity_shared = pm.MutableData('velocity_scaled', velocity_scaled, dims = 'observation')
        session_exercise_id = pm.MutableData('session_exercise_id', session_exercise_id, dims = 'session')
        session_id = pm.MutableData('session_id', session_id, dims = 'observation')
        # Global Parameters
        intercept_global = pm.Normal(name = 'intercept_global',
                                     mu = 0.0,
                                     sigma = 1.0)
        intercept_sigma_global = pm.HalfNormal(name = 'intercept_sigma_global',
                                               sigma = 1.0)
        slope_global = pm.HalfNormal(name = 'slope_global',
                                     sigma = 1.0)
        curve_global = pm.HalfNormal(name = 'curve_global',
                                     sigma = 3.0)
        error_global = pm.HalfNormal(name = 'error_global',
                                     sigma = 1.0)
        # Exercise Parameters
        intercept_offset_exercise = pm.Normal(name = 'intercept_offset_exercise',
                                              mu = 0.0,
                                              sigma = 1.0,
                                              dims = 'exercise')
        intercept_exercise = pm.Deterministic(name = 'intercept_exercise',
                                              var = intercept_global + intercept_sigma_global*intercept_offset_exercise,
                                              dims = 'exercise')
        intercept_sigma_exercise = pm.HalfNormal(name = 'intercept_sigma_exercise',
                                                 sigma = 1.0,
                                                 dims = 'exercise')
        slope_exercise = pm.HalfNormal(name = 'slope_exercise',
                                       sigma = slope_global,
                                       dims = 'exercise')
        curve_exercise = pm.HalfNormal(name = 'curve_exercise',
                                       sigma = curve_global,
                                       dims = 'exercise')
        error_exercise = pm.HalfNormal(name = 'error_exercise',
                                       sigma = error_global,
                                       dims = 'exercise')
        # Session Parameters
        intercept_offset_session = pm.Normal(name = 'intercept_offset_session',
                                             mu = 0.0,
                                             sigma = 1.0,
                                             dims = 'session')
        intercept_session = pm.Deterministic(name = 'intercept_session',
                                             var = (intercept_exercise[session_exercise_id]
                                                    + intercept_sigma_exercise[session_exercise_id]
                                                    * intercept_offset_session),
                                             dims = 'session')
        slope_session = pm.HalfNormal(name = 'slope_session',
                                      sigma = slope_exercise[session_exercise_id],
                                      dims = 'session')
        curve_session = pm.HalfNormal(name = 'curve_session',
                                      sigma = curve_exercise[session_exercise_id],
                                      dims = 'session')
        error_session = pm.HalfNormal(name = 'error_session',
                                      sigma = error_exercise[session_exercise_id],
                                      dims = 'session')
        # Final Parameters
        intercept = intercept_session[session_id]
        slope = slope_session[session_id]
        curve = curve_session[session_id]
        error = error_session[session_id]
        # Estimated Value
        load_mu = pm.Deterministic(name = 'load_mu',
                                   var = intercept - slope*velocity_shared - curve*velocity_shared**2,
                                   dims = 'observation')
        # Likelihood
        load_likelihood = pm.Normal(name = 'load_estimate',
                                    mu = load_mu,
                                    sigma = error,
                                    observed = load_scaled,
                                    dims = 'observation')
    if render_model:
        display(pm.model_to_graphviz(model))
    return model
I set up my own local environment and see the same issue.
Is there something inherently wrong with my model?
I've also tried making sure that all of the data are either float32 or int32.
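A minimal sketch of that casting (placeholder values, not my real arrays), plus the backend check that should confirm whether JAX actually sees the GPU:

import jax
import numpy as np

# Backend check: should report 'gpu' and list a CUDA device on a GPU runtime
print(jax.default_backend())
print(jax.devices())

# Explicit casting of the model inputs (placeholder values shown here)
velocity_scaled = np.asarray([0.12, 0.34, 0.56], dtype = np.float32)
load_scaled = np.asarray([0.9, 0.7, 0.5], dtype = np.float32)
session_exercise_id = np.asarray([0, 1], dtype = np.int32)
session_id = np.asarray([0, 0, 1], dtype = np.int32)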