Unable to Replicate the sampling speed given in the example

Thanks a lot @jessegrabowski! Running the commands from Anaconda Prompt solved the installation error for pymc4. I am also now able to replicate the sampling speed of the example I mentioned in my previous comment.

Interestingly, the same model compilation in pymc3 produced the following error:
theano_compilation_error_zjcos1le.txt (32.4 KB)
I am not worried about running pymc3 as long as pymc4 is running fine.

As I have previously mentioned, I am working with discrete variables, and my goal is to implement a Dirichlet Process Mixture (DPM). I started with the following finite-mixture example, which generates data from a mixture of 5 normals with a common variance.

import numpy as np

N = 500
mu = [-8, -3, 0, 3, 8]
sigma = 2
alloc_num = np.dot([0.15, 0.3, 0.25, 0.2, 0.1], N).astype(int)

d1 = np.random.normal(mu[0], sigma, alloc_num[0])
d2 = np.random.normal(mu[1], sigma, alloc_num[1])
d3 = np.random.normal(mu[2], sigma, alloc_num[2])
d4 = np.random.normal(mu[3], sigma, alloc_num[3])
d5 = np.random.normal(mu[4], sigma, alloc_num[4])
data = np.array([*d1, *d2, *d3, *d4, *d5])
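As a side note, the same data can be generated more compactly with `np.repeat`; this is just a sketch of an equivalent vectorized version (the seed is my own choice, not part of the original code):

```python
import numpy as np

rng = np.random.default_rng(42)  # arbitrary seed for reproducibility

N = 500
mu = np.array([-8, -3, 0, 3, 8])
sigma = 2
weights = np.array([0.15, 0.3, 0.25, 0.2, 0.1])
alloc_num = (weights * N).astype(int)

# Repeat each component mean according to its allocation, then add noise
data = rng.normal(np.repeat(mu, alloc_num), sigma)
```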

I sampled the following model with pm.sampling_jax.sample_numpyro_nuts(draws=2000, tune=2000, chains=4, progress_bar=True):

def build_model(pm):
    with pm.Model() as hierarchical_model:
        mu_vec = pm.Normal("mu_vec", mu=0.0, sigma=15.0, shape=5)
        common_sigma = pm.HalfNormal("common_sigma", sigma=5.0)

        # component distributions share the common standard deviation
        components = pm.Normal.dist(mu=mu_vec, sigma=common_sigma, shape=5)
        pi = pm.Dirichlet("pi", a=np.ones(5))

        like = pm.Mixture("like", w=pi, comp_dists=components, observed=data)

    return hierarchical_model

It ran in 11 sec! Next, I wanted to try the same model with a latent discrete indicator that marks the mixture component generating each data point. The code looks like this:

def build_model(pm):
    with pm.Model() as hierarchical_model:
        mu_vec = pm.Normal("mu_vec", mu=0.0, sigma=15.0, shape=5)
        common_sigma = pm.HalfNormal("common_sigma", sigma=5.0)

        pi = pm.Dirichlet("pi", a=np.ones(5))
        # latent component indicator, one per data point
        ind = pm.Categorical("ind", p=pi, shape=len(data))

        like = pm.Normal(
            "like", mu=mu_vec[ind], sigma=common_sigma, observed=data
        )

    return hierarchical_model

I know that this code is not compatible with sample_numpyro_nuts, so I used the standard sampler: pm.sample(draws=2000, tune=2000, chains=4, progress_bar=False, target_accept=0.9). This took 8 minutes, which is comparatively very slow. Since the model includes latent discrete variables, can you kindly tell me which step methods it is using? I have read that the vanilla PyMC sampler can combine NUTS with Metropolis steppers for discrete variables. Kindly enlighten me regarding the best way to set this up.

I understand that marginalization is a solution, but in my original problem, marginalizing over 20 components requires evaluating the likelihood 20 times (once per component, each with its own parameters) for every subject, and each likelihood evaluation is costly. This is the reason why I am interested in the implementation using latent discrete indicators.
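To make the cost concrete, this plain NumPy/SciPy sketch (not PyMC, and with an arbitrary toy likelihood) shows what marginalization computes: one likelihood evaluation per (subject, component) pair, i.e. K evaluations per data point, combined with logsumexp:

```python
import numpy as np
from scipy.special import logsumexp
from scipy.stats import norm

K = 20  # number of mixture components
rng = np.random.default_rng(0)
data = rng.normal(size=100)
mu = rng.normal(size=K)               # toy component parameters
log_w = np.log(np.full(K, 1.0 / K))   # uniform mixture weights

# Shape (N, K): the likelihood is evaluated K times for each subject
log_lik = norm.logpdf(data[:, None], loc=mu[None, :], scale=1.0)

# Marginal log-likelihood: logsumexp over components, summed over subjects
marginal_loglik = logsumexp(log_w[None, :] + log_lik, axis=1).sum()
```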