Porting the Stan oxen example to PyMC3

[update]
Using logsumexp() to combine the two log-probability terms for kids whose ox status was not observed seems to work OK.

Q: Is there a better way to do the same thing?

with pm.Model() as model_oxen:
    # Beta(2, 2) priors on the three unknown probabilities.
    p_cheat = pm.Beta('p_cheat', 2, 2)      # P(drinks tea | ox not stabled)
    p_drink = pm.Beta('p_drink', 2, 2)      # P(drinks tea | ox stabled)
    p_stabled = pm.Beta('p_stabled', 2, 2)  # P(kid stabled the ox)

    # One set of likelihood nodes per kid; each node name carries the
    # zero-padded kid index as a suffix.
    for idx in range(N_kids):
        tag = '{0:03d}'.format(idx)
        stabled = s_obs[idx]
        tea = tea_obs[idx]

        if stabled != 2:
            # Ox status observed: the tea probability is p_drink when the ox
            # was stabled (1) and p_cheat when it was not (0); the stabled
            # flag itself is also treated as an observed Bernoulli draw.
            tea_prob = pm.Deterministic(
                'pi_' + tag,
                stabled * p_drink + (1 - stabled) * p_cheat,
            )
            pm.Bernoulli('likelihood_tea_' + tag, tea_prob, observed=tea)
            pm.Bernoulli('likelihood_s_obs_' + tag, p_stabled, observed=stabled)
            continue

        # Ox status not observed (coded as 2): sum over both possible
        # states in log space and add the result to the model log-probability
        # through a Potential.
        logp_tea_stabled = pm.Bernoulli.dist(p_drink).logp(tea)
        logp_tea_not_stabled = pm.Bernoulli.dist(p_cheat).logp(tea)
        branch_stabled = pm.math.log(p_stabled) + logp_tea_stabled
        branch_not_stabled = pm.math.log(1 - p_stabled) + logp_tea_not_stabled
        pm.Potential(
            'likelihood_s_not_obs_' + tag,
            pm.math.logsumexp([branch_stabled, branch_not_stabled]),
        )