Feedback on RCT model

Hi,

I’m new to PyMC3 and am seeking feedback on the following model. The goal is to model open rates (binomial) for RCTs. In addition to trial and control groups, there are also distributions (email batches) sent out at different times.

Any feedback, critique, or suggestions for improvement are very much appreciated. Specifically, it would be helpful to know whether the model structure makes sense and, if so, what a reasonable sigma value for the baseline open rate (b0 in the logistic) would be.
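
To make the sigma question concrete, here is a minimal sketch (plain Python, no PyMC3 needed) of the range of baseline open rates implied by the b0 prior used in the model below. It takes the prior mean plus or minus two sigma on the logit scale and maps both ends through the logistic function. The two-sigma band and the variable names `low`, `center`, and `high` are my own choices for illustration.

```python
import math

def logistic(x):
    # Inverse logit: log-odds -> probability
    return 1 / (1 + math.exp(-x))

# Prior on b0, copied from the model: Normal(mu=-ln(4), sigma=0.125)
mu = -math.log(4)
sigma = 0.125

# Roughly 95% of the prior mass lies within mu +/- 2*sigma on the logit scale
low = logistic(mu - 2 * sigma)
center = logistic(mu)
high = logistic(mu + 2 * sigma)

print(f"prior open rate: {low:.3f} to {high:.3f} (center {center:.3f})")
```

With sigma = 0.125 this works out to a baseline open rate of roughly 16% to 24%, so the prior is fairly tight around 20%.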

Model:

import numpy as np
import pandas as pd
import pymc3 as pm

data = pd.DataFrame({'trial_control': [0, 1, 0, 1, 0, 1],
                     'distribution': [0, 0, 1, 1, 2, 2],
                     'sent': [100, 500, 1000, 5000, 10000, 50000],
                     'opens': [20, 150, 200, 1500, 2000, 15000]})

# Index arrays and group counts for trial/control and for distributions
tc_idx = data['trial_control'].values
n_tc = data['trial_control'].nunique()
dist_idx = data['distribution'].values
n_dist = data['distribution'].nunique()

# Priors
base_open_rate_mu = -np.log(4)  # logistic(-np.log(4)) = 20%
base_open_rate_sigma = 0.125
trial_control_mu = 0
trial_control_sigma = 0.5
distribution_mu = 0
distribution_sigma = 0.5


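def logistic(x):
    # Inverse logit: maps log-odds x to a probability in (0, 1).
    # (pm.math.invlogit is the built-in PyMC3 equivalent.)
    return 1 / (1 + np.exp(-x))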


with pm.Model() as open_model:
    # Intercept for the logistic fn (baseline open rate: 0.2 = logistic(-ln(4))).
    b0 = pm.Normal('b0', base_open_rate_mu, base_open_rate_sigma)

    tc_beta = pm.Normal('trial_control_beta', trial_control_mu, trial_control_sigma, shape=n_tc)
    dist_beta = pm.Normal('distribution_beta', distribution_mu, distribution_sigma, shape=n_dist)

    # Logistic (outputs P(open))
    theta = pm.Deterministic('theta', logistic(b0 + tc_beta[tc_idx] + dist_beta[dist_idx]))
    
    # Data likelihood
    likelihood_opens = pm.Binomial('data', p=theta, n=data['sent'].values, observed=data['opens'].values)


# Inference to approximate the posterior distribution
with open_model:
    trace = pm.sample(draws=5000, tune=1000, init='auto')

print(pm.summary(trace))
pm.traceplot(trace)
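
As a sanity check to compare the posterior against, here is a small sketch (plain Python, values copied from the example data above) that pools opens and sends by trial_control group and computes the raw difference on the logit scale. The names `pooled` and `raw_effect` are just for illustration, and the difference is taken as group 1 minus group 0.

```python
import math

# Values copied from the example DataFrame above
trial_control = [0, 1, 0, 1, 0, 1]
sent = [100, 500, 1000, 5000, 10000, 50000]
opens = [20, 150, 200, 1500, 2000, 15000]

def logit(p):
    # Log-odds of a probability p in (0, 1)
    return math.log(p / (1 - p))

# Pooled open rate per trial_control group (total opens / total sent)
pooled = {}
for g in sorted(set(trial_control)):
    total_opens = sum(o for o, t in zip(opens, trial_control) if t == g)
    total_sent = sum(s for s, t in zip(sent, trial_control) if t == g)
    pooled[g] = total_opens / total_sent

# Raw group difference on the logit scale (group 1 minus group 0)
raw_effect = logit(pooled[1]) - logit(pooled[0])

print(pooled)
print(round(raw_effect, 3))
```

For this toy data the pooled rates are 20% and 30%, a logit difference of about 0.54. I would expect trial_control_beta[1] - trial_control_beta[0] in the trace to land near that, which is also about one prior sigma (0.5) away from zero.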