Hello Amir, and possibly other members seeing this two years later: I have solved your (or at least a similar) problem in a different way, with a hierarchical model.
The transformation between the Probability Mass Function (PMF, the posterior of the first step) and the Dirichlet alpha vector (the prior for the second step) can be done by fitting a hierarchy of `Gamma` → `Dirichlet` → `Multinomial`.
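Written out as equations, the hierarchy looks like this, with K categories, N observations and the Gamma hyperparameters taken from the code further down (PyMC3's `pm.Gamma(alpha, beta)` treats `beta` as a rate, so the second argument here is a rate):

```latex
\alpha_k \sim \mathrm{Gamma}(2,\ 0.5), \qquad k = 1, \dots, K \\
w \sim \mathrm{Dirichlet}(\alpha_1, \dots, \alpha_K) \\
y_i \sim \mathrm{Multinomial}(1,\ w), \qquad i = 1, \dots, N
```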
- The `Gamma` is a vector of n distributions, one per category, each modelling one entry of your alpha vector. So there is one `Gamma` distribution per alpha entry.
- This vector of `Gamma` distributions can then be plugged into `weights = pm.Dirichlet('name', a=gammas)`.
- The `weights` (the `Dirichlet` variable) are then used as the category probabilities of the `Multinomial` distribution that you want to fit. In the `Multinomial`, you pass samples drawn from the first step's posterior (the PMF) to the `observed` field.
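To see what each layer in the list above does, here is a forward simulation of the same hierarchy in plain NumPy. It is only a sketch of the generative direction, not the fitting step; `K`, `N` and the seed are hypothetical, and because NumPy's `gamma` takes a *scale*, PyMC3's rate `beta=0.5` becomes `scale=2.0` here.

```python
import numpy as np

rng = np.random.default_rng(0)
K = 7    # number of categories (hypothetical)
N = 700  # number of one-hot observations (hypothetical)

# Layer 1: one Gamma per alpha entry (shape=2, rate=0.5 -> scale=2.0)
alpha = rng.gamma(shape=2.0, scale=1 / 0.5, size=K)

# Layer 2: the Gamma vector is the concentration of a single Dirichlet draw
weights = rng.dirichlet(alpha)

# Layer 3: each observation is a Multinomial(1, weights) trial, i.e. a one-hot row
y = rng.multinomial(1, weights, size=N)

print(alpha.round(2))
print(weights.round(3))
print(y.shape, y.sum(axis=0))
```

Fitting the model runs this in reverse: given one-hot rows like `y`, the sampler infers plausible `weights` and `alpha`.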
So, to sum up: the posterior of the first step is the Probability Mass Function. Take a reasonable number of samples from it (the code below uses 100 per category). These samples serve as the observations in the `observed` field of the `Multinomial` distribution in the third step.
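Drawing those observations from a PMF can be done with NumPy. The helper below is hypothetical (a stand-in for the `sample_PMF` in the code further down, not a PyMC3 function), and the PMF values are made up for illustration. Each row is one `Multinomial(1, pmf)` trial, which is the one-hot format `pm.Multinomial('y', 1, ...)` expects, and the column means come out close to the PMF:

```python
import numpy as np

def sample_pmf(pmf, n_samples, seed=None):
    """Draw n_samples one-hot rows from a categorical PMF (hypothetical helper)."""
    pmf = np.asarray(pmf, dtype=float)
    pmf = pmf / pmf.sum()  # guard against small rounding errors in the PMF
    rng = np.random.default_rng(seed)
    return rng.multinomial(1, pmf, size=n_samples)

pmf = [0.4, 0.1, 0.2, 0.3]  # hypothetical posterior PMF from step one
obs = sample_pmf(pmf, len(pmf) * 100, seed=1)

print(obs.shape)         # (400, 4)
print(obs.mean(axis=0))  # empirical category frequencies, close to pmf
```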
Then run `pm.sample_posterior_predictive` on the sampled chain for `alpha` and `weights`, and take the mean of each `alpha` entry over the draws. These means are the desired alphas, based on your previous posterior.
Here is code (PyMC3) for what I just wrote:
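```python
import numpy as np
import pymc3 as pm

# Posterior PMF from the first step. These are hypothetical placeholder values
# (not the ones behind the output below); replace them with your own PMF.
PMF = np.array([0.35, 0.10, 0.12, 0.18, 0.07, 0.08, 0.10])

n_samples = len(PMF) * 100  # number of categories in the PMF * 100
# One-hot draws from the PMF: each row is a single Multinomial(1, PMF) trial
obs = np.random.multinomial(1, PMF, size=n_samples)
N, K = obs.shape  # N - number of observations, K - number of categories

with pm.Model() as dirFit:
    # one Gamma prior per entry of the Dirichlet alpha vector
    alpha = pm.Gamma('alpha', alpha=2, beta=0.5, shape=K)
    w = pm.Dirichlet('weights', a=alpha)
    y = pm.Multinomial('y', n=1, p=w, observed=obs)
    # MCMC sampling (the tuning draws are already discarded by pm.sample)
    trace = pm.sample(5000, tune=5000, target_accept=0.9)

burn_in = 500  # optional extra burn-in on top of tuning
chain = trace[burn_in:]
ppc = pm.sample_posterior_predictive(chain, var_names=['alpha', 'weights'], model=dirFit)
# posterior mean of each alpha entry (mean over draws, per category)
alphas = ppc['alpha'].mean(axis=0)
print("Alphas: ", alphas)
```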
OUTPUT:
```
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [weights, alpha]
Sampling 4 chains for 5_000 tune and 5_000 draw iterations (20_000 + 20_000 draws total) took 17 seconds.
Alphas: [10.3787427 2.87890752 3.28627503 5.11770503 1.61238808 1.88795878
2.30119569]
```
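A quick sanity check on the fitted alphas: the mean of a Dirichlet is E[w_k] = alpha_k / sum(alpha), so normalising the alphas shows which category probabilities the new prior is centred on, and those should roughly reflect the PMF you started from. Their sum (the total concentration) controls how tightly the Dirichlet is spread around that mean. A minimal sketch using the alphas printed above:

```python
import numpy as np

# Alphas printed in the output above
alphas = np.array([10.3787427, 2.87890752, 3.28627503, 5.11770503,
                   1.61238808, 1.88795878, 2.30119569])

# Mean of Dirichlet(alpha): E[w_k] = alpha_k / sum(alpha)
implied_pmf = alphas / alphas.sum()

# Total concentration: larger values give a tighter Dirichlet around the mean
concentration = alphas.sum()

print(implied_pmf.round(3))
print(round(concentration, 2))
```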