I realize this is quite an old question, but I ran into the same problem and couldn't find a good answer elsewhere. Here is my solution, which I feel is reasonable, but I'd appreciate input if I've overlooked something. In particular, is this model appropriate for then looking at each pairwise difference between treatments and determining whether, e.g., p(treatment effect > 0) > 0.95? (A sketch of that computation is at the end of this answer.)
In my case, I needed to determine whether two treatments led to different chemical makeups of a sample, defined as the relative proportions of ~27 different compounds. The key distinction between this and, for instance, this example, which uses DirichletMultinomial as the likelihood, is that in my case the proportions are observed directly rather than arising from individual counts. So for me (and, I suspect, the OP), Multinomial and DirichletMultinomial are not appropriate likelihoods; the Dirichlet itself is.
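To make the distinction concrete, here is a minimal sketch of the two likelihood choices (the data names are hypothetical placeholders, and DirichletMultinomial only exists in recent pymc3 versions):
import numpy as np
import pymc3 as pm
from scipy import stats as st

n = 3
proportion_data = st.dirichlet.rvs((2, 2, 2), 20)  # toy directly-observed proportions

with pm.Model():
    conc = pm.Lognormal('conc', mu=1, sigma=1)
    frac = pm.Dirichlet('frac', a=np.ones(n))
    # if each row were raw counts, something like this would be appropriate:
    # pm.DirichletMultinomial('counts', n=row_totals, a=frac*conc, observed=count_data)
    # but with directly observed proportions, the Dirichlet itself is the likelihood:
    pm.Dirichlet('props', a=frac * conc, observed=proportion_data)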
Here’s a toy example, and my implementation of the model:
import numpy as np
import pymc3 as pm
import arviz as az
import theano.tensor as tt
from scipy import stats as st
# generate different sets of proportions from a Dirichlet distribution
p1 = st.dirichlet.rvs((20, 10, 10), 10000)
p2 = st.dirichlet.rvs((5, 10, 20), 10000)
# number of different proportions/buckets/classes
n = 3
# stack all observations together and build a treatment indicator
# (0 for the first 50 rows, 1 for the last 50)
data = np.vstack([p1[:50], p2[:50]])
is_2 = np.hstack([np.zeros(50), np.ones(50)]).astype(int)
with pm.Model() as model:
    # hyperprior over the mean proportions shared by both treatments,
    # plus a shared concentration
    u_h = pm.Dirichlet('p_h', a=np.ones(n, dtype=np.float32), shape=(n,))
    v_h = pm.Lognormal('v_h', mu=1, sigma=1, shape=(1,))
    # treatment-level mean proportions and concentrations
    u_1 = pm.Dirichlet('p_1', a=u_h*v_h, shape=(n,))
    v_1 = pm.Lognormal('v_1', mu=v_h, sigma=1, shape=(1,))
    u_2 = pm.Dirichlet('p_2', a=u_h*v_h, shape=(n,))
    v_2 = pm.Lognormal('v_2', mu=v_h, sigma=1, shape=(1,))
    # pairwise difference between the treatment-level proportions
    p_diff = pm.Deterministic('p_diff', u_1 - u_2)
    # per-observation concentration vector, selected by treatment label
    all_props = tt.stack([u_1*v_1, u_2*v_2])
    p_observed = pm.Dirichlet('proportions', a=all_props[is_2, :], observed=data)
    idata = pm.sample(return_inferencedata=True)
    summ = az.summary(idata)
This samples quickly and gives very good results:
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [v_2, p_2, v_1, p_1, v_h, p_h]
100.00% [8000/8000 00:33<00:00 Sampling 4 chains, 0 divergences]
Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 36 seconds.
# summ
mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk ess_tail r_hat
p_h[0] 0.313 0.123 0.081 0.535 0.002 0.001 4075.0 2507.0 1.0
p_h[1] 0.308 0.120 0.084 0.524 0.002 0.001 4665.0 2644.0 1.0
p_h[2] 0.379 0.132 0.149 0.644 0.002 0.001 5159.0 3241.0 1.0
v_h[0] 3.498 0.667 2.197 4.703 0.010 0.007 4635.0 2822.0 1.0
p_1[0] 0.494 0.013 0.472 0.518 0.000 0.000 4208.0 2882.0 1.0
p_1[1] 0.242 0.011 0.223 0.262 0.000 0.000 3833.0 2857.0 1.0
p_1[2] 0.263 0.011 0.244 0.284 0.000 0.000 5095.0 3219.0 1.0
v_1[0] 31.114 4.344 23.202 39.385 0.062 0.044 4881.0 3249.0 1.0
p_2[0] 0.150 0.008 0.136 0.167 0.000 0.000 4282.0 3268.0 1.0
p_2[1] 0.289 0.011 0.268 0.309 0.000 0.000 4419.0 3677.0 1.0
p_2[2] 0.560 0.012 0.538 0.584 0.000 0.000 5497.0 3280.0 1.0
v_2[0] 33.671 4.639 25.383 42.550 0.067 0.048 4828.0 2792.0 1.0
p_diff[0] 0.344 0.015 0.316 0.373 0.000 0.000 4288.0 3109.0 1.0
p_diff[1] -0.047 0.015 -0.075 -0.018 0.000 0.000 4267.0 3351.0 1.0
p_diff[2] -0.297 0.016 -0.328 -0.267 0.000 0.000 5483.0 3599.0 1.0
# ground truth
(p1[:50]-p2[:50]).mean(0)
# array([ 0.34434253, -0.04815737, -0.29618517])
p1[:50].mean(0)
# array([0.49526835, 0.24025454, 0.2644771 ])
p2[:50].mean(0)
# array([0.15092582, 0.28841191, 0.56066227])
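Given the posterior samples of p_diff, the pairwise question from the top of this answer can be answered directly: the posterior probability that each component's treatment effect is positive is just the fraction of samples above zero. A minimal sketch using the idata object from above (the 0.95 cutoff is the one from my question, not anything canonical):
# per-component fraction of posterior samples with a positive difference
p_gt_zero = (idata.posterior['p_diff'] > 0).mean(dim=('chain', 'draw'))
print(p_gt_zero.values)
# flag components where p(effect > 0) > 0.95 or p(effect < 0) > 0.95
print(((p_gt_zero > 0.95) | (p_gt_zero < 0.05)).values)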