Perhaps I didn’t copy it correctly the first time. Here it is again:
import pymc as pm
import numpy as np
# Settings for the simulated hierarchical state model below.
n_subjects = 100
n_trials = 100
subject_noise = 0.5  # sigma of the per-subject Normal noise term
trial_noise = 0.5  # sigma of the per-(subject, trial) Normal noise term
# Probability of each state before any noise is applied.
state_probabilities = np.array([0.1, 0.4, 0.5])
# Derive the state count from the probability vector so the two cannot drift
# apart; a mismatch would break every "state"-dimensioned variable.
n_states = len(state_probabilities)
if not np.isclose(state_probabilities.sum(), 1.0):
    raise ValueError("state_probabilities must sum to 1")

# Coordinate labels for the PyMC model's named dimensions.
coords = {
    "subject": np.arange(n_subjects),
    "trial": np.arange(n_trials),
    "state": np.arange(n_states),
}
# Scale applied to the Dirichlet concentration (previously a bare ``* 1``).
# The concentration vector sums to this value for each (subject, trial).
dirichlet_concentration = 1

with pm.Model(coords=coords) as model:
    # Fixed log-probabilities of each state, recorded as a Deterministic so
    # they appear in the sampled output.
    initial_log_p = pm.Deterministic(
        "initial_log_p", pm.math.log(state_probabilities), dims="state"
    )
    # Per-subject offset plus a per-(subject, trial) offset.
    subj_noise = pm.Normal("subj_noise", 0, subject_noise, dims="subject")
    # Bound to ``trial_noise_rv`` rather than ``trial_noise``: rebinding that
    # name would overwrite the module-level sigma constant, so re-running this
    # block would silently pass a random variable as the sigma.
    trial_noise_rv = pm.Normal(
        "trial_noise", 0, trial_noise, dims=("subject", "trial")
    )
    total_noise = pm.Deterministic(
        "total_noise",
        subj_noise[:, None] + trial_noise_rv,
        dims=("subject", "trial"),
    )
    # Broadcast (1, 1, state) + (subject, trial, 1) -> (subject, trial, state).
    # NOTE(review): the same noise value is added to every state's logit, and
    # softmax is invariant to adding a constant along its axis, so the noise
    # cancels and final_p_non_sum equals state_probabilities for every
    # subject/trial. If per-state variation is intended, the noise terms need
    # a "state" dimension -- confirm intent.
    final_log = pm.Deterministic(
        "final_log",
        initial_log_p[None, None, :] + total_noise[..., None],
        dims=("subject", "trial", "state"),
    )
    final_p_non_sum = pm.Deterministic(
        "final_p_non_sum",
        pm.math.softmax(final_log, axis=-1),
        dims=("subject", "trial", "state"),
    )
    dirichlet_sample = pm.Dirichlet(
        "dirichlet_sample",
        a=final_p_non_sum * dirichlet_concentration,
        dims=("subject", "trial", "state"),
    )
    # n=1: one draw per (subject, trial), so counts over "state" sum to 1.
    sample_state = pm.Multinomial(
        "sample_state",
        n=1,
        p=dirichlet_sample,
        dims=("subject", "trial", "state"),
    )
    # NOTE(review): newer PyMC releases name this argument ``draws``;
    # ``samples`` is kept for the PyMC version this was written against -- confirm.
    prior = pm.sample_prior_predictive(samples=1, random_seed=1)

# Softmax output summed over states; should be 1 (up to float error) everywhere.
state_prob_sums = prior.prior["final_p_non_sum"].sum("state").to_numpy()