Model oversamples a low-probability event

Perhaps I didn't copy it correctly the first time, so here it is again:

import pymc as pm
import numpy as np


# Prior-predictive simulation of a hierarchical categorical model: every
# (subject, trial) pair draws a probability vector over `n_states` from a
# Dirichlet centred on `state_probabilities`, then draws one state from it.

n_subjects = 100
n_trials = 100

n_states = 3

# Standard deviations of the subject- and trial-level noise priors. The `_sd`
# suffix keeps these constants from being shadowed by the model variables
# `subj_noise` / `trial_noise` defined inside the model below.
subject_noise_sd = 0.5
trial_noise_sd = 0.5

# Scale applied to the Dirichlet concentration. At 1.0 the concentrations
# equal `state_probabilities` and sum to 1, so per-trial probability vectors
# are very spread out (component variance p * (1 - p) / 2). Larger values
# shrink that spread toward `state_probabilities`; the mean is unaffected.
dirichlet_concentration = 1.0

state_probabilities = np.array([0.1, 0.4, 0.5])


coords = {
    "subject": np.arange(n_subjects),
    "trial": np.arange(n_trials),
    "state": np.arange(n_states),
}


with pm.Model(coords=coords) as model:
    initial_log_p = pm.Deterministic(
        "initial_log_p", pm.math.log(state_probabilities), dims="state"
    )

    subj_noise = pm.Normal("subj_noise", 0, subject_noise_sd, dims="subject")
    trial_noise = pm.Normal(
        "trial_noise", 0, trial_noise_sd, dims=("subject", "trial")
    )

    total_noise = pm.Deterministic(
        "total_noise", subj_noise[:, None] + trial_noise, dims=("subject", "trial")
    )

    # NOTE(review): the same noise value is broadcast to every state
    # (`total_noise[..., None]`), and softmax is invariant to adding one
    # constant to all logits. Since `initial_log_p` is the log of a normalised
    # vector, `final_p_non_sum` equals `state_probabilities` for every
    # subject/trial, i.e. the noise terms have no effect on the outcome.
    # If per-state variation is intended, the noise needs a "state"
    # dimension -- TODO confirm the intended model.
    final_log = pm.Deterministic(
        "final_log",
        initial_log_p[None, None, :] + total_noise[..., None],
        dims=("subject", "trial", "state"),
    )

    final_p_non_sum = pm.Deterministic(
        "final_p_non_sum",
        pm.math.softmax(final_log, axis=-1),
        dims=("subject", "trial", "state"),
    )

    # The Dirichlet mean is a / sum(a) == final_p_non_sum for any positive
    # scale, so each state's expected marginal frequency in `sample_state`
    # is final_p_non_sum; the scale only controls per-trial dispersion.
    dirichlet_sample = pm.Dirichlet(
        "dirichlet_sample",
        a=final_p_non_sum * dirichlet_concentration,
        dims=("subject", "trial", "state"),
    )

    # One draw per (subject, trial): a one-hot vector over "state".
    sample_state = pm.Multinomial(
        "sample_state", n=1, p=dirichlet_sample, dims=("subject", "trial", "state")
    )

    prior = pm.sample_prior_predictive(samples=1, random_seed=1)

# Sanity check: the softmax output sums to 1 over "state" for every draw.
print(prior.prior["final_p_non_sum"].sum("state").to_numpy())

# Diagnostic for the reported over-sampling: compare the empirical frequency
# of each sampled state, pooled over all subjects and trials, with the target.
empirical_freq = prior.prior["sample_state"].mean(("chain", "draw", "subject", "trial"))
print("target:   ", state_probabilities)
print("empirical:", empirical_freq.to_numpy())