Are there instances when reparameterizing a model makes it slower?

I think I found the root cause of the overall performance problem: I had the data in long format, which is the issue described in this post: https://discourse.pymc.io/t/big-drop-in-sampling-performance-using-long-data-table-format

After switching to aggregated data, the centered model runs in less than 20 seconds, but the non-centered version still takes nearly 3 times as long (see the new models below).
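For anyone hitting the same thing, here is a rough sketch of the long-to-aggregated step. The `analysis_long` frame and its values are made up for illustration; only the output column names (`group_handle`, `total`, `converted`) match what the models below use, and I'm assuming `group_handle` is already an integer group index.

```python
import pandas as pd

# Hypothetical long-format data: one row per user, converted is 1 or 0.
analysis_long = pd.DataFrame({
    "group_handle": [0, 0, 0, 1, 1, 2, 2, 2, 2],
    "converted":    [0, 1, 0, 0, 0, 1, 0, 0, 1],
})

# Collapse to one row per group: number of trials and number of successes.
analysis_agg = (
    analysis_long
    .groupby("group_handle", as_index=False)
    .agg(total=("converted", "count"), converted=("converted", "sum"))
)

print(analysis_agg)
```

With the data in this shape, the likelihood is one Binomial term per group rather than one Bernoulli term per row.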

Can you elaborate on this? The prior it generates once the model is run seems to behave just fine, even if it is a bit slow.

I'm not sure I follow here either. The prior has a lot of weight around -4 because that's what the data informs it to be. E.g. inv_logit(-4) ≈ 0.018, or 1.8%, which is in the ballpark of the expected conversion rate.
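A quick way to double-check that number, using only the standard library. The `inv_logit` and `logit` helpers here are defined just for this check and are not PyMC functions.

```python
import math

def inv_logit(x: float) -> float:
    """Logistic function: maps a log-odds value to a probability."""
    return 1.0 / (1.0 + math.exp(-x))

def logit(p: float) -> float:
    """Inverse of inv_logit: maps a probability to log-odds."""
    return math.log(p / (1.0 - p))

rate = inv_logit(-4.0)
print(f"inv_logit(-4) = {rate:.4f}")        # about 1.8%
print(f"logit({rate:.4f}) = {logit(rate):.2f}")
```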

Aggregated/Centered:

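# Assumes PyMC is imported as pm (import pymc3 as pm, or import pymc as pm on v4+).
with pm.Model() as overall_rate_agg:
    # Hyperpriors on the logit scale, shared by all groups
    all_groups_mean = pm.Normal("all_groups_mean", 0, 1.5)
    all_groups_sigma = pm.Exponential("all_groups_sigma", 1.2)

    # Centered: group intercepts are drawn directly from Normal(mean, sigma)
    group_intercept = pm.Normal(
        "group_intercept",
        all_groups_mean,
        all_groups_sigma,
        shape=len(groups)
    )

    # One probability per aggregated row, looked up by group index
    prob_of_converting = pm.math.invlogit(group_intercept[analysis_agg['group_handle']])

    # One Binomial term per row: 'converted' successes out of 'total' trials
    became_pos_user = pm.Binomial(
        "converted",
        p=prob_of_converting,
        n=analysis_agg['total'],
        observed=analysis_agg['converted']
    )

    prior_overall_agg = pm.sample_prior_predictive()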

Aggregated/Non-Centered:

with pm.Model() as overall_rate_agg_reparam:
    # Same hyperpriors as the centered model
    all_groups_mean = pm.Normal("all_groups_mean", 0, 1.5)
    all_groups_sigma = pm.Exponential("all_groups_sigma", 1.2)

    # Non-centered: sample standardized offsets, then shift and scale them
    z = pm.Normal("z", 0.0, 1.0, shape=len(groups))
    group_intercept = pm.Deterministic(
        "group_intercept",
        all_groups_mean + z * all_groups_sigma
    )

    # Reuse the deterministic instead of rebuilding the same expression
    prob_of_converting = pm.math.invlogit(
        group_intercept[analysis_agg['group_handle']]
    )

    # One Binomial term per row: 'converted' successes out of 'total' trials
    became_pos_user = pm.Binomial(
        "converted",
        p=prob_of_converting,
        n=analysis_agg['total'],
        observed=analysis_agg['converted']
    )

    prior_overall_agg_reparam = pm.sample_prior_predictive()
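
For reference, this is how I understand the two models to relate. The symbols are shorthand I'm introducing here: \mu for all_groups_mean, \sigma for all_groups_sigma, \alpha_g for group_intercept, and z_g for z.

```latex
% Centered
\alpha_g \sim \mathrm{Normal}(\mu, \sigma)

% Non-centered
z_g \sim \mathrm{Normal}(0, 1), \qquad \alpha_g = \mu + \sigma \, z_g
```

Both imply the same prior on \alpha_g; only the quantities the sampler actually explores differ (\alpha_g in the first, z_g in the second). As a general rule of thumb rather than a diagnosis of this model, the non-centered form tends to help when each group has little data, while the centered form can sample better when each group is well informed by its data, which may be part of why the reparameterized version is slower here.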