Model initialized with ADVI, still incredibly slow sampling

Hi, I’m running a model of binary choice (yes/no) in a decision-making task. For most versions of the model, I can run sampling in about 1.5 hours after initializing with ADVI, and results look good.

However, when I add a power parameter (from prospect/utility theory, controlling the curvature of reward devaluation; a quick sketch of the transform is below), the same model suddenly takes over 10 hours to run on average. For comparison, adding the same parameter to the equivalent Stan model in R doesn’t noticeably increase sampling time relative to the model without the prospect theory power parameter.
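For context, the power parameter just enters as a standard prospect-theory-style value transform of the stake levels, roughly like this (illustrative numbers only, not my real data):

import numpy as np

stakes = np.array([1.0, 2.0, 4.0, 8.0])   # objective stake levels
rho = 0.7                                  # power (curvature) parameter

subjective_value = stakes ** rho           # rho < 1 devalues larger stakes (concave curve)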

I tried profiling, and I think I’ve incorporated the tips about setting Theano’s floatX to float32 and setting amdlibm to True in the config (though maybe I’m not doing that correctly; I’m setting both through os.environ before importing Theano; see the sketch after the error below). When I tried to profile the gradient specifically, I got a traceback saying that gradient is not defined:

NameError: name 'gradient' is not defined
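For reference, this is roughly how I’m setting the Theano flags and what I tried for the gradient profile (the gradient import below is my guess at what the profiling examples expect, so that may be where I’m going wrong):

import os
# Set flags before Theano is imported; I'm assuming 'lib.amdlibm' is the right key for the amdlibm switch
os.environ["THEANO_FLAGS"] = "floatX=float32,lib.amdlibm=True"

import numpy as np
import pymc3 as pm
import theano.tensor as tt
from pymc3.theanof import gradient  # my guess at the missing name behind the NameError

# After building the model (below), I profile with something like:
#     model = SVFP_model_beta_scale_base_fatigue_prospect_profile
#     model.profile(model.logpt).summary()                        # profile the log-posterior
#     model.profile(gradient(model.logpt, model.vars)).summary()  # profile its gradient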

I’m guessing the slow sampling has to do with Theano and the way I’ve set up the power parameter, but I don’t know how to speed it up. Any help on this would be greatly appreciated! Model below.


# (imports and Theano flags as above; data_model, PtpList and ptp_idx come from my data preparation)

SVFP_model_beta_scale_base_fatigue_prospect_profile = pm.Model()

with SVFP_model_beta_scale_base_fatigue_prospect_profile:
    n_ptps = len(PtpList)

    # Hyperpriors for parameters
    mu_beta_rew    = pm.HalfNormal('mu_beta_rew', sd=10.0)
    sigma_beta_rew = pm.HalfNormal('sigma_beta_rew', sd=10.0)

    mu_beta_threat    = pm.HalfNormal('mu_beta_threat', sd=10.0)
    sigma_beta_threat = pm.HalfNormal('sigma_beta_threat', sd=10.0)

    mu_scale_rew    = pm.StudentT('mu_scale_rew', nu=3, mu=0., sd=10.0)
    sigma_scale_rew = pm.HalfNormal('sigma_scale_rew', sd=10.0)

    mu_scale_threat    = pm.StudentT('mu_scale_threat', nu=3, mu=0., sd=10.0)
    sigma_scale_threat = pm.HalfNormal('sigma_scale_threat', sd=10.0)

    mu_baseline_rew    = pm.StudentT('mu_baseline_rew', nu=3, mu=0., sd=10.0)
    sigma_baseline_rew = pm.HalfNormal('sigma_baseline_rew', sd=10.0)

    mu_baseline_threat    = pm.StudentT('mu_baseline_threat', nu=3, mu=0., sd=10.0)
    sigma_baseline_threat = pm.HalfNormal('sigma_baseline_threat', sd=10.0)

    mu_fatigue_rew    = pm.StudentT('mu_fatigue_rew', nu=3, mu=0., sd=10.0)
    sigma_fatigue_rew = pm.HalfNormal('sigma_fatigue_rew', sd=10.0)

    mu_fatigue_threat    = pm.StudentT('mu_fatigue_threat', nu=3, mu=0., sd=10.0)
    sigma_fatigue_threat = pm.HalfNormal('sigma_fatigue_threat', sd=10.0)

    mu_prospect_rew    = pm.HalfNormal('mu_prospect_rew', sd=10.0)
    sigma_prospect_rew = pm.HalfNormal('sigma_prospect_rew', sd=10.0)

    mu_prospect_threat    = pm.HalfNormal('mu_prospect_threat', sd=10.0)
    sigma_prospect_threat = pm.HalfNormal('sigma_prospect_threat', sd=10.0)


    # Priors for per-participant parameters
    beta_rew = pm.Normal('beta_rew', mu=mu_beta_rew, sd=sigma_beta_rew, shape=n_ptps)
    beta_threat = pm.Normal('beta_threat', mu=mu_beta_threat, sd=sigma_beta_threat, shape=n_ptps)
    scale_rew = pm.Normal('scale_rew', mu=mu_scale_rew, sd=sigma_scale_rew, shape=n_ptps)
    scale_threat = pm.Normal('scale_threat', mu=mu_scale_threat, sd=sigma_scale_threat, shape=n_ptps)
    baseline_rew = pm.Normal('baseline_rew', mu=mu_baseline_rew, sd=sigma_baseline_rew, shape=n_ptps)
    baseline_threat = pm.Normal('baseline_threat', mu=mu_baseline_threat, sd=sigma_baseline_threat, shape=n_ptps)
    fatigue_rew = pm.Normal('fatigue_rew', mu=mu_fatigue_rew, sd=sigma_fatigue_rew, shape=n_ptps)
    fatigue_threat = pm.Normal('fatigue_threat', mu=mu_fatigue_threat, sd=sigma_fatigue_threat, shape=n_ptps)


    # Prospect (power) parameters, non-centred: prospect = mu + sigma * tilde
    prospect_tilde_rew = pm.Normal('prospect_t_rew', mu=0, sd=10, shape=n_ptps)
    prospect_rew = pm.Deterministic('prospect_rew', mu_prospect_rew + sigma_prospect_rew*prospect_tilde_rew)

    prospect_tilde_threat = pm.Normal('prospect_t_threat', mu=0, sd=10, shape=n_ptps)
    prospect_threat = pm.Deterministic('prospect_threat', mu_prospect_threat + sigma_prospect_threat*prospect_tilde_threat)


    # Power transform of the stake levels (the addition that slows sampling down)
    stakelevel_prospect_rew = pm.Deterministic('stakelevel_prospect_rew', tt.pow(data_model.stakelevel.to_list(), prospect_rew[ptp_idx]))
    stakelevel_prospect_threat = pm.Deterministic('stakelevel_prospect_threat', tt.pow(data_model.stakelevel.to_list(), prospect_threat[ptp_idx]))

    # Calculate predictions given values
    # rewardscenario indicator is 0 for threat trials (else 1)
    # threatscenario indicator is 0 for reward trials (else 1)
    # So one invlogit term is zero for every trial according to scenario
    yhat = pm.invlogit((beta_rew[ptp_idx] * data_model.rewardscenario_ind * stakelevel_prospect_rew) -
                       (beta_rew[ptp_idx] * data_model.rewardscenario_ind * scale_rew[ptp_idx] * data_model.effort_threshold2) -
                       (beta_rew[ptp_idx] * data_model.rewardscenario_ind * baseline_rew[ptp_idx]) -
                       (beta_rew[ptp_idx] * data_model.rewardscenario_ind * fatigue_rew[ptp_idx] * data_model.trial_number_scaled) +
                       (beta_threat[ptp_idx] * data_model.threatscenario_ind * stakelevel_prospect_threat) -
                       (beta_threat[ptp_idx] * data_model.threatscenario_ind * scale_threat[ptp_idx] * data_model.effort_threshold2) -
                       (beta_threat[ptp_idx] * data_model.threatscenario_ind * baseline_threat[ptp_idx]) -
                       (beta_threat[ptp_idx] * data_model.threatscenario_ind * fatigue_threat[ptp_idx] * data_model.trial_number_scaled))

    # Likelihood (each trial is a single yes/no choice)
    y = pm.Binomial('y', n=np.ones(data_model.shape[0]), p=yhat,
                    observed=data_model.choice)

And sampling is done with:
with SVFP_model_beta_scale_base_fatigue_prospect_profile:
    SVFP_trace_beta_scale_base_fatigue_prospect_profile = pm.sample(2000, tune=2000,
                                                                    init='advi', chains=2,
                                                                    nuts_kwargs={"target_accept": 0.99,
                                                                                 "max_treedepth": 15})