Why do StudentT likelihoods seem slower to sample than Normal likelihoods?

Hello,

Is it normal for a model to take longer to sample when the likelihood follows a Student T distribution rather than a normal distribution, or could the problem be my model?

Below is my model. When switching my likelihood from normal to Student T, the sampling times seem to double:

# Hierarchical regression model: a per-(item, location) intercept plus six
# per-item coefficients, each built through a chain of hierarchy levels
# (business_line -> category -> subcategory -> ic -> item), with a StudentT
# likelihood on the observed residuals.
with pm.Model(coords=coords) as constant_model:    
    # Immutable index arrays. Each one is used below to index a parent-level
    # variable when building the mean of the next (child) level.
    cat_to_bl_map = pm.Data('cat_to_bl_map', cat_to_bl_idx, mutable=False)
    subcat_to_cat_map = pm.Data('subcat_to_cat_map', subcat_to_cat_idx, mutable=False)
    ic_to_subcat_map = pm.Data('ic_to_subcat_map', ic_to_subcat_idx, mutable=False)
    # NOTE(review): the name reads "ic -> item", but it is used to index the
    # ic-level variables when building the item level, i.e. the same role as
    # the maps above — confirm it holds one ic index per item.
    ic_to_item_map = pm.Data('ic_to_item_map', ic_to_item_idx, mutable = False)

    # Mutable data containers: per-observation indices, regressors and the
    # observed target. Declared mutable so they can be replaced later.
    pm_loc_idx = pm.Data('loc_idx', location_idx, mutable = True)
    pm_item_idx = pm.Data('item_idx', item_idx, mutable=True)
    # NOTE(review): pm_time_idx is not referenced anywhere else in this block.
    pm_time_idx = pm.Data('time_idx', time_idx, mutable=True)
    observed_eaches = pm.Data('observed_eaches', df_train.residual, mutable=True)
    dc_discount_ = pm.Data('dc_discount', dc_idx, mutable = True)
    month_8_ = pm.Data('month_8', month_8_idx ,mutable = True)
    month_9_ = pm.Data('month_9', month_9_idx ,mutable = True)
    # NOTE(review): "thrity" is a typo in the Python name only; the model-side
    # data name is 'count_30'.
    count_thrity_ = pm.Data('count_30', count_thirty_idx, mutable = True)
    price_change_before_ = pm.Data('price_change_before', price_change_before_idx, mutable = True)
    price_change_on_ = pm.Data('price_change_on', df_train['price_change_on_flag'].values, mutable = True)
                                  
    # Intercept hierarchy, carried per location at every level.
    # make_next_level_hierarchy_variable is defined outside this snippet;
    # presumably it creates a child-level variable centred on `mu` with a
    # scale prior controlled by alpha/beta — TODO confirm against the helper.
    loc_intercept = pm.Normal('loc_intercept', mu = 0, sigma = 1, dims = ['location'])
    loc_bl = utility_functions.make_next_level_hierarchy_variable(name='loc_bl', mu=loc_intercept, alpha=2, beta=1, dims=['business_line', 'location'])
    loc_cat = utility_functions.make_next_level_hierarchy_variable(name='loc_cat', mu=loc_bl[cat_to_bl_map], alpha=2, beta=1, dims=['category', 'location'])
    loc_subcat = utility_functions.make_next_level_hierarchy_variable(name='loc_subcat', mu=loc_cat[subcat_to_cat_map], alpha=2, beta=1, dims=['subcategory', 'location'])
    loc_ic = utility_functions.make_next_level_hierarchy_variable(name='loc_ic', mu=loc_subcat[ic_to_subcat_map], alpha=2, beta=1, dims=['ic', 'location'])
    loc_item = utility_functions.make_next_level_hierarchy_variable(name='loc_item', mu=loc_ic[ic_to_item_map], alpha=2, beta=1, dims=['item', 'location'])

    # Coefficient hierarchy for the dc_discount regressor: a global Normal(0, 1)
    # mean, then one level per hierarchy tier down to the item level.
    mu_dc_discount = pm.Normal('mu_dc_discount', mu = 0, sigma = 1)
    bl_dc_discount = utility_functions.make_next_level_hierarchy_variable(name='bl_dc_discount', mu=mu_dc_discount, alpha=2, beta=1, dims=['business_line'])
    cat_dc_discount = utility_functions.make_next_level_hierarchy_variable(name='cat_dc_discount', mu=bl_dc_discount[cat_to_bl_map], alpha=2, beta=1,  dims=['category'])
    subcat_dc_discount = utility_functions.make_next_level_hierarchy_variable(name='subcat_dc_discount', mu=cat_dc_discount[subcat_to_cat_map],  alpha=2, beta=1, dims=['subcategory'])
    ic_dc_discount = utility_functions.make_next_level_hierarchy_variable(name='ic_dc_discount', mu=subcat_dc_discount[ic_to_subcat_map],  alpha=2, beta=1, dims=['ic'])
    item_dc_discount = utility_functions.make_next_level_hierarchy_variable(name='item_dc_discount', mu=ic_dc_discount[ic_to_item_map], alpha=2, beta=1,  dims=['item'])

    # Coefficient hierarchy for the month_8 regressor (same structure as above).
    mu_month8 = pm.Normal('mu_month8', mu = 0, sigma = 1)
    bl_month8 = utility_functions.make_next_level_hierarchy_variable(name='bl_month8', mu=mu_month8, alpha=2, beta=1, dims=['business_line'])
    cat_month8 = utility_functions.make_next_level_hierarchy_variable(name='cat_month8', mu=bl_month8[cat_to_bl_map], alpha=2, beta=1,  dims=['category'])
    subcat_month8 = utility_functions.make_next_level_hierarchy_variable(name='subcat_month8', mu=cat_month8[subcat_to_cat_map],  alpha=2, beta=1, dims=['subcategory'])
    ic_month8 = utility_functions.make_next_level_hierarchy_variable(name='ic_month8', mu=subcat_month8[ic_to_subcat_map],  alpha=2, beta=1, dims=['ic'])
    item_month8 = utility_functions.make_next_level_hierarchy_variable(name='item_month8', mu=ic_month8[ic_to_item_map], alpha=2, beta=1,  dims=['item'])

    # Coefficient hierarchy for the month_9 regressor.
    mu_month9 = pm.Normal('mu_month9', mu = 0, sigma = 1)
    bl_month9 = utility_functions.make_next_level_hierarchy_variable(name='bl_month9', mu=mu_month9, alpha=2, beta=1, dims=['business_line'])
    cat_month9 = utility_functions.make_next_level_hierarchy_variable(name='cat_month9', mu=bl_month9[cat_to_bl_map], alpha=2, beta=1,  dims=['category'])
    subcat_month9 = utility_functions.make_next_level_hierarchy_variable(name='subcat_month9', mu=cat_month9[subcat_to_cat_map],  alpha=2, beta=1, dims=['subcategory'])
    ic_month9 = utility_functions.make_next_level_hierarchy_variable(name='ic_month9', mu=subcat_month9[ic_to_subcat_map],  alpha=2, beta=1, dims=['ic'])
    item_month9 = utility_functions.make_next_level_hierarchy_variable(name='item_month9', mu=ic_month9[ic_to_item_map], alpha=2, beta=1,  dims=['item'])

    # Coefficient hierarchy for the count_30 regressor.
    mu_count30 = pm.Normal('mu_count30', mu = 0, sigma = 1)
    bl_count30 = utility_functions.make_next_level_hierarchy_variable(name='bl_count30', mu=mu_count30, alpha=2, beta=1, dims=['business_line'])
    cat_count30 = utility_functions.make_next_level_hierarchy_variable(name='cat_count30', mu=bl_count30[cat_to_bl_map], alpha=2, beta=1,  dims=['category'])
    subcat_count30 = utility_functions.make_next_level_hierarchy_variable(name='subcat_count30', mu=cat_count30[subcat_to_cat_map],  alpha=2, beta=1, dims=['subcategory'])
    ic_count30 = utility_functions.make_next_level_hierarchy_variable(name='ic_count30', mu=subcat_count30[ic_to_subcat_map],  alpha=2, beta=1, dims=['ic'])
    item_count30 = utility_functions.make_next_level_hierarchy_variable(name='item_count30', mu=ic_count30[ic_to_item_map], alpha=2, beta=1,  dims=['item'])
    
    # Coefficient hierarchy for the price_change_before regressor.
    mu_price_change_before = pm.Normal('mu_price_change_before', mu = 0, sigma = 1)
    bl_price_change_before = utility_functions.make_next_level_hierarchy_variable(name='bl_price_change_before', mu=mu_price_change_before, alpha=2, beta=1, dims=['business_line'])
    cat_price_change_before = utility_functions.make_next_level_hierarchy_variable(name='cat_price_change_before', mu=bl_price_change_before[cat_to_bl_map], alpha=2, beta=1,  dims=['category'])
    subcat_price_change_before = utility_functions.make_next_level_hierarchy_variable(name='subcat_price_change_before', mu=cat_price_change_before[subcat_to_cat_map],  alpha=2, beta=1, dims=['subcategory'])
    ic_price_change_before = utility_functions.make_next_level_hierarchy_variable(name='ic_price_change_before', mu=subcat_price_change_before[ic_to_subcat_map],  alpha=2, beta=1, dims=['ic'])
    item_price_change_before = utility_functions.make_next_level_hierarchy_variable(name='item_price_change_before', mu=ic_price_change_before[ic_to_item_map], alpha=2, beta=1,  dims=['item'])

    # Coefficient hierarchy for the price_change_on regressor.
    mu_price_change_on = pm.Normal('mu_price_change_on', mu = 0, sigma = 1)
    bl_price_change_on = utility_functions.make_next_level_hierarchy_variable(name='bl_price_change_on', mu=mu_price_change_on, alpha=2, beta=1, dims=['business_line'])
    cat_price_change_on = utility_functions.make_next_level_hierarchy_variable(name='cat_price_change_on', mu=bl_price_change_on[cat_to_bl_map], alpha=2, beta=1,  dims=['category'])
    subcat_price_change_on = utility_functions.make_next_level_hierarchy_variable(name='subcat_price_change_on', mu=cat_price_change_on[subcat_to_cat_map],  alpha=2, beta=1, dims=['subcategory'])
    ic_price_change_on = utility_functions.make_next_level_hierarchy_variable(name='ic_price_change_on', mu=subcat_price_change_on[ic_to_subcat_map],  alpha=2, beta=1, dims=['ic'])
    item_price_change_on = utility_functions.make_next_level_hierarchy_variable(name='item_price_change_on', mu=ic_price_change_on[ic_to_item_map], alpha=2, beta=1,  dims=['item'])

    # Per-observation mean: the (item, location) intercept plus each item-level
    # coefficient multiplied by its regressor, all gathered with the
    # observation's item index.
    mu = (loc_item[pm_item_idx, pm_loc_idx] 
          + item_month9[pm_item_idx]*month_9_ 
          + item_count30[pm_item_idx]*count_thrity_
          + item_dc_discount[pm_item_idx]*dc_discount_
          + item_price_change_before[pm_item_idx]*price_change_before_ 
          + item_price_change_on[pm_item_idx]*price_change_on_
          + item_month8[pm_item_idx]*month_8_ 
    )
    
    # Prior on the StudentT degrees of freedom. The positional arguments are
    # PyMC's (alpha, beta), i.e. Gamma(alpha=2, beta=0.1).
    nu_ = pm.Gamma('nu', 2, 0.1) 
    
    # Likelihood on the observed residuals.
    # NOTE(review): no sigma/lam is passed, so the scale is whatever PyMC
    # defaults to rather than a learned parameter — confirm this is intended.
    pm.StudentT('predicted_eaches',
                         mu=mu,
                         nu= nu_, 
                         observed=observed_eaches,)

Could this be a problem with the model, or is the StudentT just that much slower to sample?

That’s not very surprising to me. The StudentT likelihood is considerably more complicated to evaluate than the Normal one. The extra nu parameter may also slow things down if it’s not very well identified.

2 Likes