Mixture of Bernoullis

I am trying to set up the following experiment: I have two bags of two coins each, and each coin has a different probability of returning heads. The code below is meant to serve as a framework for testing multivariate imputation, which pymc cannot handle “yet”. However, I do not recover the correct probabilities. Is everything done correctly? I cannot find anything wrong with it.

import pymc as pm  # pymc4
import pandas as pd
import numpy as np
import pymc.distributions.discrete as discrete
import pymc.distributions.continuous as continuous
import aesara.tensor as tt
import aesara as ae  # need version 2.7.1

import pymc.sampling_jax as jx

import jax
jax.config.update('jax_platform_name', 'cpu')
jax.default_backend(), jax.local_device_count()
print(jax.numpy.ones(3).device()) # TFRT_CPU_0

def create_data():
    # Data Generation
    p_bag = {}
    p_bag[0] = (0.4, 0.2)  # bag 1
    p_bag[1] = (0.5, 0.8)  # bag 2
    probs = [p_bag[i] for i in range(2)] 
    pi = (0.3, 0.7)   # 30% of the time, flip coins in bag 1, 70% of the time, flip coins in bag 2
    nb_flips = 5000
    
    bags = np.random.choice(2, size=nb_flips, p=pi)
    
    # select the bag of coins with probability pi (shape = 2)
    coins_flips = []  # list of H/T for coin1, and another list for coin 2

    coin0_bag = {}
    coin1_bag = {}
    
    # For efficiency, generate all the flip data ahead of time.
    # Note: np.random.choice(2, p=(p, 1 - p)) returns 0 with probability p,
    # so the value 0 is the outcome with probability p_bag[i][coin].
    for i in range(2):  # number of bags: 2
        coin0_bag[i] = np.random.choice(2, size=nb_flips, p=(p_bag[i][0], 1-p_bag[i][0]))  # coin 0 in bag i
        coin1_bag[i] = np.random.choice(2, size=nb_flips, p=(p_bag[i][1], 1-p_bag[i][1]))  # coin 1 in bag i
            
    for flip in range(nb_flips):
        bag = bags[flip]
        prob = probs[bag]
        coin0 = coin0_bag[bag][flip] 
        coin1 = coin1_bag[bag][flip]
        coins_flips.append([coin0, coin1])
        
    # Create a missing value
    # coins_flips[10][0] = np.NaN

    coins_flips = np.array(coins_flips)
    print(coins_flips.shape)
    return coins_flips

def create_coin(data):
    """ 
    Two bags of two coins, flipped several times
    """
    pi = [0.3, 0.7]  # mixture weights, treated as known rather than inferred
    with pm.Model() as coin_model:
        # All coin probabilities get Beta(2, 2) priors
        p_coins_bag1 = pm.Beta("p_bag1", 2, 2, shape=2)
        p_coins_bag2 = pm.Beta("p_bag2", 2, 2, shape=2)
        components = [
            pm.Bernoulli.dist(p=p_coins_bag1),
            pm.Bernoulli.dist(p=p_coins_bag2)
        ]
        pm.Mixture("mix", w=pi, comp_dists=components, observed=data)
    return coin_model

nb_chains = 2
nb_draws = 2000
data = create_data()
model = create_coin(data)
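
As a quick sanity check on the simulated data (a small snippet added for this post; it does not affect the model), I look at the pooled frequency of the value 1 for each coin:

# Empirical frequency of the value 1 per coin, pooled over both bags.
print("empirical P(flip == 1) per coin:", data.mean(axis=0))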

%%time

with model:
    # tune = number of warmup (adaptation) iterations
    idata = jx.sample_numpyro_nuts(target_accept=0.9, draws=nb_draws, tune=nb_draws, chains=nb_chains, chain_method='parallel')
    # idata = jx.sample_blackjax_nuts(target_accept=0.9, draws=nb_draws, tune=0, chains=4, chain_method='vectorized') 

# RESULTS
d = idata.posterior["p_bag1"].values             # shape: (chains, draws, 2)
print("bag 1 posteriors: ", np.mean(d, axis=1))  # posterior mean per chain

d = idata.posterior["p_bag2"].values
print("bag 2 posteriors: ", np.mean(d, axis=1))

The resulting posterior means (for the coins in the two bags) should match the values I used to generate the data, but they do not. Here are my results:

bag 1 posteriors:  [[0.5177841  0.46175354]
 [0.51814241 0.45773814]]
bag 2 posteriors:  [[0.53347442 0.34817736]
 [0.53375591 0.34963206]]

The correct results should be:

    p_bag[0] = (0.4, 0.2)  # bag 1
    p_bag[1] = (0.5, 0.8)  # bag 2

Perhaps I have not run enough samples? Another observation: the inference runs almost two orders of magnitude more slowly than a comparable single-coin model (sketched below for reference), which gives me the impression that Mixture models are rather inefficient. But I am just guessing. Thanks for any insight!
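
For reference, here is roughly the single-coin baseline I compare the timing against (a sketch; my actual run may have differed slightly):

with pm.Model() as single_coin_model:
    # One coin with a Beta(2, 2) prior, observing 5000 simulated flips
    p = pm.Beta("p", 2, 2)
    pm.Bernoulli("flips", p=p, observed=np.random.binomial(1, 0.4, size=5000))

with single_coin_model:
    idata_single = jx.sample_numpyro_nuts(target_accept=0.9, draws=nb_draws, tune=nb_draws,
                                          chains=nb_chains, chain_method='parallel')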