Hierarchical Multinomial - die with funny readback

OK, here is an odd one that I could use some guidance on.

Eventually I will generalize this to outputs from an instrument. To get there in an understandable way, consider an unfair 4-sided die where, instead of knowing how many 1s, 2s, 3s, and 4s were rolled (like in this example), we know how many times we got an even number, a 2, and a 3. I think this should still allow us to estimate the probability of rolling each of 1, 2, 3, and 4.
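If the total number of rolls N is also known (an assumption; the question does not say so explicitly), the three observed counts already determine all four face counts: odd = N - even, 1s = odd - 3s, and 4s = even - 2s. A minimal sketch with a hypothetical helper name `face_counts`:

```python
def face_counts(n_total, n_even, n_two, n_three):
    """Recover the counts of faces 1..4 from total rolls, even rolls, 2s and 3s.

    Hypothetical helper for illustration; assumes n_total is known.
    """
    n_odd = n_total - n_even
    n_one = n_odd - n_three
    n_four = n_even - n_two
    counts = [n_one, n_two, n_three, n_four]
    if min(counts) < 0:
        raise ValueError("inconsistent counts: some face would be negative")
    return counts


# 50 rolls: 30 even, 10 twos, 15 threes
print(face_counts(50, 30, 10, 15))  # [5, 10, 15, 20]
```

Because this mapping is one-to-one, a plain Multinomial on the reconstructed counts uses all of the information in (N, even, 2, 3).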

I am thinking of this as a hierarchical Multinomial, but my logic is not turning into code I can write.

A roll is either even or odd; an even roll is either 2 or 4, and an odd roll is either 1 or 3. That forms a tree:

    Roll -> Even -> 2
                 -> 4
         -> Odd  -> 1
                 -> 3

It seems like this should be buildable in such a way that the observations can be any of the nodes in the tree.
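The tree can be written down directly: each face probability is the product of the branch probabilities along its path. A small plain-Python sketch; the names `p_even` (P(even)), `p2_given_even` (P(2 | even)) and `p1_given_odd` (P(1 | odd)) are hypothetical:

```python
def leaf_probs(p_even, p2_given_even, p1_given_odd):
    """Map the three branch probabilities of the tree to P(1), P(2), P(3), P(4)."""
    p_odd = 1.0 - p_even
    return [
        p_odd * p1_given_odd,            # Roll -> Odd  -> 1
        p_even * p2_given_even,          # Roll -> Even -> 2
        p_odd * (1.0 - p1_given_odd),    # Roll -> Odd  -> 3
        p_even * (1.0 - p2_given_even),  # Roll -> Even -> 4
    ]


# The [10%, 20%, 30%, 40%] die: P(even) = 0.6, P(2 | even) = 1/3, P(1 | odd) = 1/4
probs = leaf_probs(0.6, 1 / 3, 0.25)
print([round(p, 3) for p in probs])  # [0.1, 0.2, 0.3, 0.4]
```

Any three numbers in [0, 1] give four non-negative probabilities that sum to 1, so independent priors on the three branches always describe a valid die.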

Does anyone have any similar examples or thoughts on how to poke at this?

So I got there in one way that seems to work. Does anyone have any thoughts on improvements?

import numpy as np
import pymc3 as pm

# Observe the counts of 1, 2, 3, 4 directly to get the reference answer.
# A fair die would give 1/4 = 0.25 for every face; we want an unfair one,
# so use [10%, 20%, 30%, 40%] of 50 rolls (rounded to integer counts).
observations = np.round(50 * np.array([0.1, 0.2, 0.3, 0.4])).astype(int)
with pm.Model():
    probs = pm.Dirichlet('probs', a=np.ones(4))  # flat prior over the four faces
    rolls = pm.Multinomial('rolls', n=observations.sum(), p=probs, observed=observations)
    trace = pm.sample(5000)
# pm.traceplot(trace)
pm.plot_posterior(trace);
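Because the Dirichlet prior is conjugate to the Multinomial, this first model can also be checked without MCMC: with a Dirichlet(1, 1, 1, 1) prior and counts n_k, the posterior is Dirichlet(1 + n_k), whose mean is (1 + n_k) / (N + 4). A minimal NumPy sketch (`Generator.dirichlet` draws from the exact posterior; the seed and number of draws are arbitrary choices):

```python
import numpy as np

counts = np.array([5, 10, 15, 20])  # observed 1s, 2s, 3s, 4s
alpha_prior = np.ones(4)            # flat Dirichlet prior, as in the PyMC model
alpha_post = alpha_prior + counts   # conjugate update

exact_mean = alpha_post / alpha_post.sum()

rng = np.random.default_rng(0)
draws = rng.dirichlet(alpha_post, size=20_000)  # exact posterior samples

print(np.round(exact_mean, 3))  # [0.111 0.204 0.296 0.389]
print(np.round(draws.mean(axis=0), 3))
```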

Then just build up the whole thing from chained Bernoulli distributions, where the observed data are evens, 2s, and 1s:

# this is the number of 1, 2, 3, 4
N_rolls = 50
observations = np.round(N_rolls * np.array([0.1, 0.2, 0.3, 0.4])).astype(int)
# even/odd data: even = 1, odd = 0
obs_evenodd = np.repeat([1, 0], [observations[[1, 3]].sum(), observations[[0, 2]].sum()])
# obs_2_4 from the even rolls: 2 = 1, 4 = 0
obs_2_4 = np.repeat([1, 0], [observations[1], observations[3]])
# obs_1_3 from the odd rolls: 1 = 1, 3 = 0
obs_1_3 = np.repeat([1, 0], [observations[0], observations[2]])
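These three 0/1 lists carry the same information as three binomial counts, because the Multinomial likelihood for the four faces factorizes along the tree: P(n1, n2, n3, n4) = Binomial(n_even | N, P(even)) · Binomial(n2 | n_even, P(2 | even)) · Binomial(n1 | n_odd, P(1 | odd)). (A product of Bernoulli terms differs from the matching Binomial only by a constant coefficient.) A standard-library check using `math.lgamma` for log-factorials; the function names are hypothetical:

```python
import math


def log_multinomial(counts, probs):
    """log P(counts) under a Multinomial with the given face probabilities."""
    out = math.lgamma(sum(counts) + 1)
    for k, p in zip(counts, probs):
        out += k * math.log(p) - math.lgamma(k + 1)
    return out


def log_binomial(k, n, p):
    """log P(k successes in n trials) under a Binomial(n, p)."""
    return (math.lgamma(n + 1) - math.lgamma(k + 1) - math.lgamma(n - k + 1)
            + k * math.log(p) + (n - k) * math.log(1 - p))


def log_tree(counts, p_even, p2_given_even, p1_given_odd):
    """The same likelihood, built from the three nodes of the even/odd tree."""
    n1, n2, n3, n4 = counts
    n_even, n_odd = n2 + n4, n1 + n3
    return (log_binomial(n_even, n_even + n_odd, p_even)
            + log_binomial(n2, n_even, p2_given_even)
            + log_binomial(n1, n_odd, p1_given_odd))


counts = [5, 10, 15, 20]
p_even, p2e, p1o = 0.55, 0.3, 0.2  # arbitrary branch probabilities
probs = [(1 - p_even) * p1o, p_even * p2e, (1 - p_even) * (1 - p1o), p_even * (1 - p2e)]
print(log_multinomial(counts, probs), log_tree(counts, p_even, p2e, p1o))
```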

with pm.Model() as our_first_model:
    # Root of the tree: P(even)
    p_evenodd = pm.Beta('p_evenodd', alpha=1, beta=1)
    evenodd = pm.Bernoulli('evenodd', p=p_evenodd, observed=obs_evenodd)

    # Even branch: P(2 | even)
    p_2_4 = pm.Beta('p_2_4', alpha=1, beta=1)
    b_2_4 = pm.Bernoulli('b_2_4', p=p_2_4, observed=obs_2_4)

    # Odd branch: P(1 | odd)
    p_1_3 = pm.Beta('p_1_3', alpha=1, beta=1)
    b_1_3 = pm.Bernoulli('b_1_3', p=p_1_3, observed=obs_1_3)

    # Each face probability is the product of the branch probabilities on its path
    p1 = pm.Deterministic('p1', (1 - p_evenodd) * p_1_3)
    p2 = pm.Deterministic('p2', p_evenodd * p_2_4)
    p3 = pm.Deterministic('p3', (1 - p_evenodd) * (1 - p_1_3))
    p4 = pm.Deterministic('p4', p_evenodd * (1 - p_2_4))  # 4 is the "not 2" leaf of the even branch

    trace = pm.sample(5000)

pm.traceplot(trace, var_names=['p1', 'p2', 'p3', 'p4'])
pm.plot_posterior(trace, var_names=['p1', 'p2', 'p3', 'p4']);

And they give essentially the same answer.
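The two models agree closely but not exactly, because their priors differ slightly. By the aggregation property of the Dirichlet distribution, Dirichlet(1, 1, 1, 1) on (P(1), P(2), P(3), P(4)) is equivalent to independent P(even) ~ Beta(2, 2), P(2 | even) ~ Beta(1, 1) and P(1 | odd) ~ Beta(1, 1); the chained model instead puts Beta(1, 1) on P(even). The sketch below computes exact posterior means with `fractions.Fraction` for both root priors. It relies on the three nodes being independent a posteriori (independent priors, factorized likelihood), so the mean of each product is the product of the means. The function names are hypothetical:

```python
from fractions import Fraction


def beta_mean(a, b):
    return Fraction(a, a + b)


def tree_posterior_means(counts, root_prior):
    """Exact posterior means of P(1..4) for the even/odd tree.

    root_prior is (a, b) for the Beta prior on P(even); the two conditional
    nodes use flat Beta(1, 1) priors.
    """
    n1, n2, n3, n4 = counts
    a0, b0 = root_prior
    m_even = beta_mean(a0 + n2 + n4, b0 + n1 + n3)
    m_2 = beta_mean(1 + n2, 1 + n4)  # P(2 | even)
    m_1 = beta_mean(1 + n1, 1 + n3)  # P(1 | odd)
    return [(1 - m_even) * m_1, m_even * m_2, (1 - m_even) * (1 - m_1), m_even * (1 - m_2)]


def dirichlet_posterior_means(counts):
    """Exact posterior means under a Dirichlet(1, 1, 1, 1) prior."""
    total = sum(counts) + len(counts)
    return [Fraction(1 + n, total) for n in counts]


counts = [5, 10, 15, 20]
print([float(m) for m in tree_posterior_means(counts, (1, 1))])  # chained model as written
print([float(m) for m in tree_posterior_means(counts, (2, 2))])  # matches the Dirichlet model
print([float(m) for m in dirichlet_posterior_means(counts)])
```

With 50 rolls the difference shows up around the third decimal place and shrinks as more rolls are observed; giving `p_evenodd` a `Beta(alpha=2, beta=2)` prior would make the two models identical.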