Using pymc3 to compute credible intervals for dependent proportions

I realize this is quite an old question, but I ran into the same problem and couldn’t find a good answer elsewhere. Here is my solution, which I feel is reasonable, but I’d appreciate input if I’ve overlooked something. In particular, is this model appropriate for then looking at each pairwise difference between treatments and checking whether, e.g., p(treatment effect > 0) > 0.95? (A sketch of that computation is at the end of this post.)

In my case, I needed to determine whether two treatments led to different chemical makeups of a sample, defined as the relative proportions of ~27 different compounds. The key distinction between this and, for instance, this example, which uses DirichletMultinomial as the likelihood, is that in my case the proportions are observed directly rather than arising from individual counts. So for me (and, I suspect, the OP), Multinomial or DirichletMultinomial are not appropriate likelihoods; the Dirichlet is.

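One note on the parameterization I use below: each Dirichlet concentration vector is decomposed as a = u*v, where u is a vector of mean proportions (summing to 1) and v is a scalar precision, so larger v concentrates draws around u. Here is a quick sanity check of that reading using scipy (the variable names are just for illustration):

import numpy as np
from scipy import stats as st

u = np.array([0.5, 0.3, 0.2])  # mean proportions, sum to 1
for v in (2.0, 20.0, 200.0):   # increasing precision
    draws = st.dirichlet.rvs(u * v, size=10_000, random_state=0)
    print(v, draws.mean(0).round(3), draws.std(0).round(3))
# the sample means stay near u while the spread shrinks as v grows
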
Here’s a toy example, and my implementation of the model:

import numpy as np
import pymc3 as pm
import arviz as az
import theano.tensor as tt  # needed for tt.stack below
from scipy import stats as st

# generate different sets of proportions from a Dirichlet distribution
p1 = st.dirichlet.rvs((20, 10, 10), 10000)
p2 = st.dirichlet.rvs((5, 10, 20), 10000)

# number of different proportions/buckets/classes
n = 3

# stack all observations together
data = np.vstack([p1[:50], p2[:50]])
is_2 = np.hstack([np.zeros(50), np.ones(50)]).astype(int)

with pm.Model() as model:
    # hyperpriors: overall mean proportions and precision shared by both treatments
    u_h = pm.Dirichlet('p_h', a=np.ones(n, dtype=np.float32), shape=(n,))
    v_h = pm.Lognormal('v_h', mu=1, sigma=1, shape=(1,))

    # per-treatment mean proportions and precisions, centered on the hyperpriors
    u_1 = pm.Dirichlet('p_1', a=u_h*v_h, shape=(n,))
    v_1 = pm.Lognormal('v_1', mu=v_h, sigma=1, shape=(1,))
    u_2 = pm.Dirichlet('p_2', a=u_h*v_h, shape=(n,))
    v_2 = pm.Lognormal('v_2', mu=v_h, sigma=1, shape=(1,))

    # pairwise difference in mean proportions between the two treatments
    p_diff = pm.Deterministic('p_diff', u_1 - u_2)

    # concentration parameters (mean * precision), one row per treatment
    all_props = tt.stack([u_1*v_1, u_2*v_2])

    # Dirichlet likelihood on the directly observed proportions
    p_observed = pm.Dirichlet('proportions', a=all_props[is_2, :], observed=data)

    idata = pm.sample(return_inferencedata=True)
    summ = az.summary(idata)

This samples quickly and gives very good results:

Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [v_2, p_2, v_1, p_1, v_h, p_h]

100.00% [8000/8000 00:33<00:00 Sampling 4 chains, 0 divergences]

Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 36 seconds.
# summ
             mean     sd  hdi_3%  hdi_97%  mcse_mean  mcse_sd  ess_bulk  ess_tail  r_hat
p_h[0]      0.313  0.123   0.081    0.535      0.002    0.001    4075.0    2507.0    1.0
p_h[1]      0.308  0.120   0.084    0.524      0.002    0.001    4665.0    2644.0    1.0
p_h[2]      0.379  0.132   0.149    0.644      0.002    0.001    5159.0    3241.0    1.0
v_h[0]      3.498  0.667   2.197    4.703      0.010    0.007    4635.0    2822.0    1.0
p_1[0]      0.494  0.013   0.472    0.518      0.000    0.000    4208.0    2882.0    1.0
p_1[1]      0.242  0.011   0.223    0.262      0.000    0.000    3833.0    2857.0    1.0
p_1[2]      0.263  0.011   0.244    0.284      0.000    0.000    5095.0    3219.0    1.0
v_1[0]     31.114  4.344  23.202   39.385      0.062    0.044    4881.0    3249.0    1.0
p_2[0]      0.150  0.008   0.136    0.167      0.000    0.000    4282.0    3268.0    1.0
p_2[1]      0.289  0.011   0.268    0.309      0.000    0.000    4419.0    3677.0    1.0
p_2[2]      0.560  0.012   0.538    0.584      0.000    0.000    5497.0    3280.0    1.0
v_2[0]     33.671  4.639  25.383   42.550      0.067    0.048    4828.0    2792.0    1.0
p_diff[0]   0.344  0.015   0.316    0.373      0.000    0.000    4288.0    3109.0    1.0
p_diff[1]  -0.047  0.015  -0.075   -0.018      0.000    0.000    4267.0    3351.0    1.0
p_diff[2]  -0.297  0.016  -0.328   -0.267      0.000    0.000    5483.0    3599.0    1.0
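
The hdi_3%/hdi_97% columns above are the credible intervals themselves. To visualize them, something like the following works directly on the InferenceData (ref_val=0 just marks the no-effect point):

az.plot_posterior(idata, var_names=['p_diff'], ref_val=0)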
# ground truth

(p1[:50]-p2[:50]).mean(0)
# array([ 0.34434253, -0.04815737, -0.29618517])

p1[:50].mean(0)
# array([0.49526835, 0.24025454, 0.2644771 ])

p2[:50].mean(0)
# array([0.15092582, 0.28841191, 0.56066227])