[update]
Marginalizing the unobserved ox status with logsumexp() seems to work OK.
Q: is there a better (e.g. vectorized) way to do the same thing?
with pm.Model() as model_oxen:
    # Priors on the three unknown probabilities (weak Beta(2, 2) priors).
    p_cheat = pm.Beta('p_cheat', 2, 2)      # P(drink tea | ox NOT stabled)
    p_drink = pm.Beta('p_drink', 2, 2)      # P(drink tea | ox stabled)
    p_stabled = pm.Beta('p_stabled', 2, 2)  # P(kid stabled the ox)

    for idx in np.arange(N_kids):
        if s_obs[idx] == 2:
            # Ox status not observed for this kid (sentinel value 2 in s_obs).
            # Marginalize the latent stabled/not-stabled state out in log space:
            #   log P(tea) = logsumexp( log(p_stabled)   + log P(tea | ox),
            #                           log(1-p_stabled) + log P(tea | no ox) )
            pr_tea_ox = pm.Bernoulli.dist(p_drink).logp(tea_obs[idx])
            pr_tea_noox = pm.Bernoulli.dist(p_cheat).logp(tea_obs[idx])
            a = pm.math.log(p_stabled) + pr_tea_ox
            b = pm.math.log(1 - p_stabled) + pr_tea_noox
            # Potential adds the marginal log-likelihood term to the model's
            # joint log-probability; no need to keep the returned object.
            pm.Potential('likelihood_s_not_obs_{0:03d}'.format(idx),
                         pm.math.logsumexp([a, b]))
        else:
            # Ox status observed (s_obs[idx] is 0 or 1): condition directly.
            # P(tea) is p_drink when the ox was stabled, p_cheat otherwise.
            _pi = pm.Deterministic('pi_{0:03d}'.format(idx),
                                   s_obs[idx] * p_drink + (1 - s_obs[idx]) * p_cheat)
            pm.Bernoulli('likelihood_tea_{0:03d}'.format(idx), _pi,
                         observed=tea_obs[idx])
            # The observed stabling outcome itself informs p_stabled.
            pm.Bernoulli('likelihood_s_obs_{0:03d}'.format(idx), p_stabled,
                         observed=s_obs[idx])
    # NOTE(review): answering the question above — a vectorized alternative is
    # to split the data into observed/unobserved subsets and use a single
    # pm.Bernoulli per subset plus one pm.Potential (or pm.Mixture) over the
    # whole unobserved vector; this avoids creating O(N_kids) graph nodes.
    # The per-kid loop here is mathematically equivalent — confirm before
    # refactoring.