Hello! I’m learning pymc3 and probabilistic programming (many thanks for the excellent docs and videos, especially Junpeng Lao’s work).
I came across this Stan code example and wanted to reimplement it in pymc3:
McElreath oxen blog example
Short version of the problem description:
- Kids like to drink tea and are allowed to have an evening tea if they’ve stabled the family oxen at the end of the day.
- Each evening, you need to find which children have properly stabled their oxen. For many houses, you can see whether or not the ox is stabled. But other houses have enclosed stables, so you cannot observe the ox.
- Some kids drink tea even if the ox is not stabled.
I want to find Pr(ox-not-stabled | drinking tea).
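For reference, the target quantity follows from Bayes' rule once the three probabilities in the model are known. Here is a minimal sketch in plain Python; the function name `pr_not_stabled_given_tea` and the point values passed to it are hypothetical illustrations, not numbers from the blog:

```python
def pr_not_stabled_given_tea(p_drink, p_cheat, p_stabled):
    """Pr(ox not stabled | drinking tea) via Bayes' rule."""
    # Joint probability of drinking tea with an unstabled ox
    tea_and_not_stabled = p_cheat * (1.0 - p_stabled)
    # Joint probability of drinking tea with a stabled ox
    tea_and_stabled = p_drink * p_stabled
    return tea_and_not_stabled / (tea_and_not_stabled + tea_and_stabled)


# Hypothetical point values, for illustration only
example = pr_not_stabled_given_tea(p_drink=0.9, p_cheat=0.5, p_stabled=0.75)
```

Applying the same function to each posterior draw of the three parameters should give posterior draws of the quantity of interest.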
Priors are Beta(2,2) for:
p_drink # Pr(drink-tea|ox stabled)
p_cheat # Pr(drink-tea|ox not stabled)
p_stabled # Pr(ox stabled)
tea_i ~ Bernoulli ( pi_i )
pi_i = s_i * p_drink + (1-s_i) *p_cheat
s_i ~ Bernoulli(p_stabled) # observe whether ox is stabled or not, some missing data
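To make the generative story concrete, here is a small simulation sketch using only the standard library. The helper `simulate_children`, the parameter `p_hidden` (the chance that a stable is enclosed), and the use of 2 as the "not observed" code are my own assumptions for illustration:

```python
import random


def simulate_children(n, p_drink, p_cheat, p_stabled, p_hidden, seed=0):
    """Simulate ox status, partially observed ox status, and tea drinking."""
    rng = random.Random(seed)
    s, s_obs, tea = [], [], []
    for _ in range(n):
        # s_i ~ Bernoulli(p_stabled)
        s_i = int(rng.random() < p_stabled)
        # pi_i = s_i * p_drink + (1 - s_i) * p_cheat
        pi_i = s_i * p_drink + (1 - s_i) * p_cheat
        # tea_i ~ Bernoulli(pi_i)
        tea_i = int(rng.random() < pi_i)
        # Enclosed stable: the ox status exists but is hidden (coded as 2)
        hidden = rng.random() < p_hidden
        s.append(s_i)
        s_obs.append(2 if hidden else s_i)
        tea.append(tea_i)
    return s, s_obs, tea


# Hypothetical parameter values, for illustration only
s_sim, s_obs_sim, tea_sim = simulate_children(51, 0.9, 0.5, 0.75, 0.4)
```

Fitting a model to data simulated this way is one way to check whether the known parameter values are recovered.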
I’m having issues handling the partially observed discrete variable ‘s_obs’ (the ox status, which is missing for some children).
My results don’t match the blog’s. I suspect the issue is with the (*****) section below.
What’s the appropriate way to perform the marginalization in pymc3?
This is my implementation:
# %%
import numpy as np
import pymc3 as pm

N_kids = 51
# same fake data as the blog; codes: ox stabled (1), not stabled (0), not observed (2, only in s_obs)
s = np.array([1,1,1,0,1,0,0,1,1,1,1,1,1,1,0,1,1,0,1,0,0,1,1,1,1,1,1,1,0,1,1,1,1,1,0,1,0,1,1,1,0,1,0,1,1,0,1,1,1,1,1])
s_obs = np.array([1,1,2,2,2,0,0,1,1,2,2,2,2,1,2,1,1,2,2,0,0,2,1,2,2,1,2,1,2,2,1,1,1,1,2,1,0,1,2,1,2,1,2,2,1,0,1,1,1,1,1])
s_not_obs_idx = np.argwhere(s_obs==2)
s_obs_idx = np.argwhere(s_obs!=2)
tea_obs = np.array([1,1,1,0,1,0,0,1,1,1,1,1,1,1,1,1,1,1,1,0,1,1,1,1,1,1,1,1,1,1,1,1,1,1,0,1,1,1,1,1,0,1,0,1,1,1,1,1,1,1,1])
with pm.Model() as model_oxen:
    # priors
    p_cheat = pm.Beta('p_cheat', 2, 2)      # drink tea if ox not stabled
    p_drink = pm.Beta('p_drink', 2, 2)      # drink tea if ox stabled
    p_stabled = pm.Beta('p_stabled', 2, 2)  # kid stabled ox
    for idx in np.arange(N_kids):
        if s_obs[idx] == 2:
            # ox status not observed (*****)
            pr_tea_ox = pm.Bernoulli('pr_tea_ox_{0:03d}'.format(idx), p_drink, observed=tea_obs[idx])
            pr_tea_noox = pm.Bernoulli('pr_tea_noox_{0:03d}'.format(idx), p_cheat, observed=tea_obs[idx])
            _likelihood_s_not_obs = p_stabled*pr_tea_ox + (1-p_stabled)*pr_tea_noox
            likelihood_s_not_obs = pm.Potential('likelihood_s_not_obs_{0:03d}'.format(idx), _likelihood_s_not_obs)
            # (this also does not work) likelihood_s_not_obs = pm.Deterministic('likelihood_s_not_obs_{0:03d}'.format(idx), _likelihood_s_not_obs)
        else:
            # ox status observed
            _pi = pm.Deterministic('pi_{0:03d}'.format(idx), s_obs[idx]*p_drink + (1-s_obs[idx])*p_cheat)
            likelihood_tea = pm.Bernoulli('likelihood_tea_{0:03d}'.format(idx), _pi, observed=tea_obs[idx])
            likelihood_s_obs = pm.Bernoulli('likelihood_s_obs_{0:03d}'.format(idx), p_stabled, observed=s_obs[idx])
# %%
with model_oxen:
    trace = pm.sample(1000, init=None, tune=500)
pm.summary(trace, var_names=['p_cheat', 'p_drink', 'p_stabled'])
This is the Stan model from the blog:
model{
  // priors
  p_cheat ~ beta(2,2);
  p_drink ~ beta(2,2);
  sigma ~ beta(2,2);
  // probability of tea
  for ( i in 1:N_children ) {
    if ( s[i] == -1 ) {
      // ox unobserved
      // log_mix(): log( sigma*bernoulli(tea[i]|p_drink) + (1-sigma)*bernoulli(tea[i]|p_cheat) )
      target += log_mix(
        sigma ,
        bernoulli_lpmf( tea[i] | p_drink ) ,
        bernoulli_lpmf( tea[i] | p_cheat ) );
    } else {
      // ox observed
      tea[i] ~ bernoulli( s[i]*p_drink + (1-s[i])*p_cheat );
      s[i] ~ bernoulli( sigma );
    }
  }//i
}
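To spell out what the `log_mix` term in the Stan model computes for a child whose ox is unobserved, here is a minimal sketch in plain Python. The helper names `bernoulli_lpmf`, `log_mix`, and `unobserved_ox_loglik` are written here to mirror the Stan functions and are not part of pymc3; the parameter values are hypothetical:

```python
import math


def bernoulli_lpmf(y, p):
    """Log probability of a 0/1 outcome y under Bernoulli(p)."""
    return math.log(p) if y == 1 else math.log1p(-p)


def log_mix(theta, lp1, lp2):
    """log(theta * exp(lp1) + (1 - theta) * exp(lp2)), computed stably."""
    a = math.log(theta) + lp1
    b = math.log1p(-theta) + lp2
    m = max(a, b)  # log-sum-exp trick to avoid underflow
    return m + math.log(math.exp(a - m) + math.exp(b - m))


def unobserved_ox_loglik(tea_i, p_drink, p_cheat, sigma):
    """Log-likelihood of tea_i with the ox status summed out."""
    return log_mix(
        sigma,
        bernoulli_lpmf(tea_i, p_drink),
        bernoulli_lpmf(tea_i, p_cheat),
    )


# Hypothetical parameter values, for illustration only
term = unobserved_ox_loglik(1, p_drink=0.9, p_cheat=0.5, sigma=0.75)
```

Note that the result is a log-probability. Since `pm.Potential` adds its argument to the model's log-probability, a pymc3 version would presumably need to build the same log-scale expression from the parameters and the observed tea value; how best to write that with tensor operations is not shown here.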