Are there instances when reparameterizing a model makes it slower?

I’ve built a multi-level logistic regression that I’m using to test whether there is a difference between treatment and control in an online experiment. There are about 90k rows being fed into the model, with a single predictor column indicating treatment/control and a single outcome column indicating success/failure to reach a certain conversion point. Both columns are encoded as integers, and the expected conversion rate is quite low (1-2%).
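For concreteness, here is a rough sketch of what a dataset like this might look like (simulated data; the names analysis_df, group_column, and conversion_column are simply the ones used in the models below):

import numpy as np
import pandas as pd

rng = np.random.default_rng(0)
n = 90_000
analysis_df = pd.DataFrame({
    # 0 = control, 1 = treatment
    'group_column': rng.integers(0, 2, size=n),
    # success/failure to convert, at roughly a 1.5% rate
    'conversion_column': rng.binomial(1, 0.015, size=n),
})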

When I run the centered model, it takes about 10 minutes to complete, has a large number of divergences when I leave the target accept rate at around 0.8, and the sampling does not look very efficient based on the summary stats. I thought reparameterizing the model would help, but when I do that, it takes nearly an hour to run and the traces are much, much worse. I’ve tried adjusting both the target acceptance rate and the number of tuning steps, but neither seems to help, and I suspect there is a deeper problem I’m overlooking.

Any ideas on what might be causing this?

The centered model looks like this:

import pymc as pm  # or `import pymc3 as pm`, depending on your PyMC version

model_df = analysis_df
conversion_column = 'conversion_column'
group_column = 'group_column'
groups = model_df[group_column].unique()  # assumed: the distinct group codes (0 = control, 1 = treatment)

with pm.Model() as overall_rate:
    all_groups_mean  = pm.Normal("all_groups_mean", 0, 1.5)
    all_groups_sigma = pm.Exponential("all_groups_sigma", 1.2)
    
    group_intercept = pm.Normal(
        "group_intercept", 
        all_groups_mean, 
        all_groups_sigma, 
        shape=len(groups)
    )
    
    prob_of_converting = pm.math.invlogit(group_intercept[model_df[group_column]])
    
    became_pos_user = pm.Bernoulli(
        "converted", 
        prob_of_converting, 
        observed=model_df[conversion_column]
    )
    
    trace_overall = pm.sample(
        target_accept=0.95,
    )

And the reparameterized model:

with pm.Model() as overall_rate_reparam:
    all_groups_mean  = pm.Normal("all_groups_mean", 0.0, 1.5)
    all_groups_sigma = pm.Exponential("all_groups_sigma", 1.2)
    z = pm.Normal("z", 0.0, 1.0, shape=len(groups))
    
    _ = pm.Deterministic("group_intercept", all_groups_mean + z * all_groups_sigma)
    prob_of_converting = pm.math.invlogit(
        all_groups_mean + z[model_df[group_column]] * all_groups_sigma
    )
    
    became_pos_user = pm.Binomial(
        "converted",
        n=1,
        p=prob_of_converting,
        observed=model_df[conversion_column]
    )
    
    prior_overall_reparam = pm.sample_prior_predictive()

How many groups are there, just two? Because two is too few to use a hierarchical distribution.

My guess is the centered model is running slowly because the prior on all_groups_mean puts very little probability on a value as extreme as -4.

Try modeling this as an intercept + dummy for treatment, and set the prior for the intercept to something like pm.Normal('intercept', 0, 4).
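A minimal sketch of that suggestion, assuming the same model_df columns as in your code (the treatment_effect name and its Normal(0, 1) prior are just illustrative):

with pm.Model() as intercept_dummy:
    # wide prior on the baseline (control) log-odds, as suggested above
    intercept = pm.Normal('intercept', 0, 4)
    # shift in log-odds for the treatment group (illustrative prior)
    treatment_effect = pm.Normal('treatment_effect', 0, 1)
    
    # group_column is 0 for control, 1 for treatment, so it doubles as the dummy
    prob_of_converting = pm.math.invlogit(
        intercept + treatment_effect * model_df[group_column].values
    )
    
    pm.Bernoulli(
        "converted",
        p=prob_of_converting,
        observed=model_df[conversion_column]
    )
    
    trace_dummy = pm.sample()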

I think I discovered the root of the overall performance problem. It’s related to this post here, as I had the data in long format: https://discourse.pymc.io/t/big-drop-in-sampling-performance-using-long-data-table-format

After switching to aggregated data, the centered model runs in less than 20 seconds, but the non-centered version still takes nearly three times that (see the new models below).
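The aggregation step itself is roughly this (a sketch assuming pandas; analysis_agg, group_handle, total, and converted are the names used in the models below):

analysis_agg = (
    analysis_df
    .groupby(group_column, as_index=False)
    .agg(
        total=(conversion_column, 'size'),     # number of users per group
        converted=(conversion_column, 'sum'),  # number of conversions per group
    )
    .rename(columns={group_column: 'group_handle'})
)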

Can you elaborate on this? Looking at the prior it generates once the model is run, it seems to be behaving just fine, even if it is a bit slow.

Not sure if I follow here either. The prior has a lot of weight around -4 because that’s what the data informs it to be. E.g. inv_logit(-4) = 0.018 or 1.8%, which is in the ballpark of what the expected conversion rate is.

Aggregated/Centered:

with pm.Model() as overall_rate_agg:
    all_groups_mean  = pm.Normal("all_groups_mean", 0, 1.5)
    all_groups_sigma = pm.Exponential("all_groups_sigma", 1.2)
    
    group_intercept = pm.Normal(
        "group_intercept", 
        all_groups_mean, 
        all_groups_sigma, 
        shape=len(groups)
    )
    
    prob_of_converting = pm.math.invlogit(group_intercept[analysis_agg['group_handle']])
    
    became_pos_user = pm.Binomial(
        "converted", 
        p=prob_of_converting,
        n=analysis_agg['total'],
        observed=analysis_agg['converted']
    )
    
    prior_overall_agg = pm.sample_prior_predictive()

Aggregated/Non-Centered:

with pm.Model() as overall_rate_agg_reparam:
    all_groups_mean  = pm.Normal("all_groups_mean", 0, 1.5)
    all_groups_sigma = pm.Exponential("all_groups_sigma", 1.2)
    z = pm.Normal("z", 0.0, 1.0, shape=len(groups))
    
    _ = pm.Deterministic(
        'group_intercept',
        all_groups_mean + z * all_groups_sigma
    )
    
    prob_of_converting = pm.math.invlogit(
        all_groups_mean + z[analysis_agg['group_handle']] * all_groups_sigma
    )
    
    became_pos_user = pm.Binomial(
        "converted", 
        p=prob_of_converting,
        n=analysis_agg['total'],
        observed=analysis_agg['converted']
    )
    
    prior_overall_agg_reparam = pm.sample_prior_predictive()

It’s hard to infer the mean and standard deviation of a distribution if you only have two draws from it!
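One way to see this empirically (a sketch, assuming PyMC >= 4 and ArviZ, reusing the aggregated centered model above): compare the prior and posterior of all_groups_sigma. With only two groups, the data barely update it, so the two distributions should look very similar.

import arviz as az

with overall_rate_agg:
    idata = pm.sample(target_accept=0.95)
    idata.extend(pm.sample_prior_predictive())

# If all_groups_sigma's posterior looks much like its prior, the two
# groups are telling the model very little about the hierarchical sd.
az.plot_dist_comparison(idata, var_names=["all_groups_sigma"])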

The posterior has a lot of weight around -4 because that’s what the data informs it to be. The prior, with a mean of 0 and a std dev of 1.5, puts very little weight on -4.
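To put rough numbers on both points (a quick check with scipy):

from scipy.stats import norm
from scipy.special import expit  # inverse logit

# Mass the Normal(0, 1.5) prior puts at or below -4: well under 1%
print(norm.cdf(-4, loc=0, scale=1.5))  # ~0.0038

# Implied conversion rate at -4, which is indeed ~1.8%
print(expit(-4))                       # ~0.018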

I mean that when you look at all_groups_mean in the posterior, it is effectively what the model has determined the prior for the group intercepts to be, since it is a hyper-prior.