Porting Stan oxen example to PyMC3

Hello! I’m learning PyMC3 and probabilistic programming (many thanks for the excellent docs and videos, especially Junpeng Lao’s work).

I came across this Stan code example and wanted to reimplement it in PyMC3:
McElreath oxen blog example

Short version of the problem description:

• Kids like to drink tea, and are allowed an evening tea only if they’ve stabled the family ox at the end of the day.
• Each evening, find which children have properly stabled their oxen. For many houses you can see whether or not the ox is stabled, but other houses have enclosed stables, so you cannot observe the ox.
• Some kids drink tea even if the ox is not stabled.

I want to find Pr(ox-not-stabled | drinking tea).
The priors are Beta(2,2) for:
p_drink   # Pr(drink-tea | ox stabled)
p_cheat   # Pr(drink-tea | ox not stabled)
p_stabled # Pr(ox stabled)

tea_i ~ Bernoulli(pi_i)
pi_i = s_i * p_drink + (1 - s_i) * p_cheat
s_i ~ Bernoulli(p_stabled) # observe whether ox is stabled or not, some missing data
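For reference, the target quantity follows from Bayes’ rule over the marginal probability of tea. A minimal numeric sketch (the point values below are made up, purely to illustrate the arithmetic; they are not posterior estimates):

```python
# Hypothetical point values, just to show the computation
p_drink, p_cheat, p_stabled = 0.9, 0.5, 0.75

# Pr(ox not stabled | drinking tea) via Bayes' rule
numerator = (1 - p_stabled) * p_cheat
denominator = p_stabled * p_drink + (1 - p_stabled) * p_cheat
pr_not_stabled_given_tea = numerator / denominator
print(pr_not_stabled_given_tea)  # → 0.15625
```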

I’m having issues handling the partially observed discrete variable `s` (the array `s_obs` encodes whether the ox status was observed).
My results don’t match the blog’s; I suspect the issue is in the (*****) section below.

What’s the appropriate way to perform the marginalization in pymc3?

This is my implementation:

```python
# %%
import numpy as np
import pymc3 as pm

N_kids = 51

# same fake data as the blog: whether the ox is stabled (1), not stabled (0),
# or not observed (2)
s = np.array([1,1,1,0,1,0,0,1,1,1,1,1,1,1,0,1,1,0,1,0,0,1,1,1,1,1,1,1,0,1,1,1,1,1,0,1,0,1,1,1,0,1,0,1,1,0,1,1,1,1,1])
s_obs = np.array([1,1,2,2,2,0,0,1,1,2,2,2,2,1,2,1,1,2,2,0,0,2,1,2,2,1,2,1,2,2,1,1,1,1,2,1,0,1,2,1,2,1,2,2,1,0,1,1,1,1,1])

s_not_obs_idx = np.argwhere(s_obs == 2)
s_obs_idx = np.argwhere(s_obs != 2)

tea_obs = np.array([1,1,1,0,1,0,0,1,1,1,1,1,1,1,1,1,1,1,1,0,1,1,1,1,1,1,1,1,1,1,1,1,1,1,0,1,1,1,1,1,0,1,0,1,1,1,1,1,1,1,1])

with pm.Model() as model_oxen:
    # priors
    p_cheat = pm.Beta('p_cheat', 2, 2)      # drink tea if ox not stabled
    p_drink = pm.Beta('p_drink', 2, 2)      # drink tea if ox stabled
    p_stabled = pm.Beta('p_stabled', 2, 2)  # kid stabled ox

    for idx in np.arange(N_kids):
        if s_obs[idx] == 2:
            # ox status not observed (*****)
            pr_tea_ox = pm.Bernoulli('pr_tea_ox_{0:03d}'.format(idx), p_drink, observed=tea_obs[idx])
            pr_tea_noox = pm.Bernoulli('pr_tea_noox_{0:03d}'.format(idx), p_cheat, observed=tea_obs[idx])

            _likelihood_s_not_obs = p_stabled * pr_tea_ox + (1 - p_stabled) * pr_tea_noox
            likelihood_s_not_obs = pm.Potential('likelihood_s_not_obs_{0:03d}'.format(idx), _likelihood_s_not_obs)
            # (this also does not work)
            # likelihood_s_not_obs = pm.Deterministic('likelihood_s_not_obs_{0:03d}'.format(idx), _likelihood_s_not_obs)
        else:
            # ox status observed
            _pi = pm.Deterministic('pi_{0:03d}'.format(idx), s_obs[idx] * p_drink + (1 - s_obs[idx]) * p_cheat)
            likelihood_tea = pm.Bernoulli('likelihood_tea_{0:03d}'.format(idx), _pi, observed=tea_obs[idx])
            likelihood_s_obs = pm.Bernoulli('likelihood_s_obs_{0:03d}'.format(idx), p_stabled, observed=s_obs[idx])

# %%
with model_oxen:
    trace = pm.sample(1000, init=None, tune=500)
pm.summary(trace, var_names=['p_cheat', 'p_drink', 'p_stabled'])
```

This is the Stan model from the blog:

```stan
model{
  // priors
  p_cheat ~ beta(2,2);
  p_drink ~ beta(2,2);
  sigma ~ beta(2,2);

  // probability of tea
  for ( i in 1:N_children ) {
    if ( s[i] == -1 ) {
      // ox unobserved
      // log_mix(): log( sigma*bernoulli(tea[i]|p_drink) + (1-sigma)*bernoulli(tea[i]|p_cheat) )
      target += log_mix(
        sigma ,
        bernoulli_lpmf( tea[i] | p_drink ) ,
        bernoulli_lpmf( tea[i] | p_cheat ) );
    } else {
      // ox observed
      tea[i] ~ bernoulli( s[i]*p_drink + (1-s[i])*p_cheat );
      s[i] ~ bernoulli( sigma );
    }
  }//i
}
```

[update]
Using logsumexp() seems to work OK.

Q: is there a better way to do the same thing?

```python
with pm.Model() as model_oxen:
    # priors
    p_cheat = pm.Beta('p_cheat', 2, 2)      # prob drink tea if ox not stabled
    p_drink = pm.Beta('p_drink', 2, 2)      # prob drink tea if ox stabled
    p_stabled = pm.Beta('p_stabled', 2, 2)  # prob kid stabled ox

    for idx in np.arange(N_kids):
        if s_obs[idx] == 2:
            # ox status not observed
            pr_tea_ox = pm.Bernoulli.dist(p_drink).logp(tea_obs[idx])
            pr_tea_noox = pm.Bernoulli.dist(p_cheat).logp(tea_obs[idx])

            a = pm.math.log(p_stabled) + pr_tea_ox
            b = pm.math.log(1 - p_stabled) + pr_tea_noox
            _likelihood_s_not_obs = pm.math.logsumexp([a, b])
            likelihood_s_not_obs = pm.Potential('likelihood_s_not_obs_{0:03d}'.format(idx), _likelihood_s_not_obs)
        else:
            # ox status observed
            _pi = pm.Deterministic('pi_{0:03d}'.format(idx), s_obs[idx] * p_drink + (1 - s_obs[idx]) * p_cheat)
            likelihood_tea = pm.Bernoulli('likelihood_tea_{0:03d}'.format(idx), _pi, observed=tea_obs[idx])
            likelihood_s_obs = pm.Bernoulli('likelihood_s_obs_{0:03d}'.format(idx), p_stabled, observed=s_obs[idx])
```
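As a sanity check outside of PyMC3: Stan’s `log_mix(w, a, b)` is just `logsumexp([log(w) + a, log(1 - w) + b])`. A small NumPy/SciPy sketch with arbitrary illustrative values (not posterior estimates):

```python
import numpy as np
from scipy.special import logsumexp
from scipy.stats import bernoulli

# arbitrary illustrative values
sigma, p_drink, p_cheat, tea = 0.7, 0.9, 0.5, 1

# log_mix form: logsumexp of the weighted component log-probabilities
a = np.log(sigma) + bernoulli.logpmf(tea, p_drink)
b = np.log(1 - sigma) + bernoulli.logpmf(tea, p_cheat)
log_mix = logsumexp([a, b])

# direct mixture probability, for comparison
direct = np.log(sigma * p_drink + (1 - sigma) * p_cheat)
assert np.isclose(log_mix, direct)
```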

You should use `pm.Mixture` for a direct translation of `log_mix` in Stan.

Thanks Junpeng,
I had some trouble with casting to theano tensors, but the new model below works OK.

Q: are the two models identical? I’m getting more accurate results with `model_oxen2` (tighter HPDI, mean is closer to ideal value)

```python
import theano.tensor as T

with pm.Model() as model_oxen2:
    # priors
    p_cheat = pm.Beta('p_cheat', 2, 2)      # prob drink tea if ox not stabled
    p_drink = pm.Beta('p_drink', 2, 2)      # prob drink tea if ox stabled
    p_stabled = pm.Beta('p_stabled', 2, 2)  # prob kid stabled ox

    # ox status not observed: marginalize out ox / no-ox with a mixture
    pr_tea_ox = pm.Bernoulli.dist(p=p_drink)
    pr_tea_noox = pm.Bernoulli.dist(p=p_cheat)

    _w = T.stack([p_stabled, 1 - p_stabled])
    _comp_dists = [pr_tea_ox, pr_tea_noox]

    likelihood_s_not_obs = pm.Mixture('likelihood_s_not_obs', _w, comp_dists=_comp_dists, observed=tea_obs[s_not_obs_idx])

    # ox status observed
    _pi = pm.Deterministic('pi', s_obs[s_obs_idx] * p_drink + (1 - s_obs[s_obs_idx]) * p_cheat)
    likelihood_tea = pm.Bernoulli('likelihood_tea', _pi, observed=tea_obs[s_obs_idx])
    likelihood_s_obs = pm.Bernoulli('likelihood_s_obs', p_stabled, observed=s_obs[s_obs_idx])
```
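Once either model is sampled, the quantity of interest, Pr(ox-not-stabled | drinking tea), can be computed per posterior draw via Bayes’ rule. A sketch with stand-in draws (with a real trace you would substitute `trace['p_drink']`, `trace['p_cheat']`, and `trace['p_stabled']` for the random arrays below):

```python
import numpy as np

# stand-in "posterior draws", purely for illustration
rng = np.random.default_rng(0)
p_drink = rng.beta(2, 2, size=1000)
p_cheat = rng.beta(2, 2, size=1000)
p_stabled = rng.beta(2, 2, size=1000)

# Pr(ox not stabled | drinking tea), one value per draw
pr = (1 - p_stabled) * p_cheat / (
    p_stabled * p_drink + (1 - p_stabled) * p_cheat)
print(pr.mean())  # posterior mean of the target quantity
```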