Hi Ricardo. One more question. I’ve got the model working for individual experiment participants now, but would like to make it hierarchical. For reference, here’s the code for my individual model.
coords = {
"participant": participants,
"time": np.arange(n_timesteps),
"state": ["xt=a", "xt=b"],
"next_state": ["xt+1=a", "xt+1=b"],
"emission": ["emission_concordant_with_a", "emission_concordant_with_b"],
}
with pm.Model(coords=coords) as model:
def step(P, x_tm1):
x_t = pm.Categorical.dist(p=P[x_tm1])
return x_t, collect_default_updates(x_t)
def markov_chain(x0, Ps, shape=None):
states, _ = pytensor.scan(step, outputs_info=[x0], sequences=[Ps])
return states
# Set transition matricies with informative priors
transition_matrix_baseline = pm.Dirichlet(
"transition_matrix_baseline",
a=[[5, 1.2], [1.2, 5]],
dims=["participant", "state", "next_state"],
)
transition_matrix_explicit = pm.Dirichlet(
"transition_matrix_explicit",
a=[[1.2, 5], [5, 1.2]],
dims=["participant", "state", "next_state"],
)
transition_matricies = pt.stack(
[transition_matrix_baseline] * n_timesteps
).dimshuffle(1, 0, 2, 3)
for participant_index, explicit_transition_indicies in enumerate(
model_data["explicit_transition_indicies"]
):
transition_matricies = transition_matricies[
participant_index, explicit_transition_indicies, :, :
].set(transition_matrix_explicit[participant_index, :, :])
transition_matricies = pm.Deterministic(
"transition_matricies",
transition_matricies,
dims=["participant", "time", "state", "next_state"],
)
p_x0s = pm.Data(
"initial_state",
[0.1 if x == "a" else 0.9 for x in model_data["initial_state"]],
dims="participant",
)
x0 = pm.Bernoulli("x0", p=p_x0s, dims="participant")
hidden_states = pm.CustomDist(
"hidden_states",
x0,
transition_matricies,
dist=markov_chain,
dims=["participant", "time"],
)
observed_emissions = pm.Data(
"obsereved_emissions",
model_data["cue_concordant_with_a"],
dims=["participant", "time"],
)
emission_ps = pm.Dirichlet(
"p_emission_concordant_with_true_state",
a=[[3, 1], [1, 3]],
dims=["participant", "state", "emission"],
)
emission_matricies = pm.math.stack([emission_ps] * n_timesteps)
emissions = pm.Categorical(
"emissions",
p=emission_matricies[0, hidden_states],
dims=["participant", "time"],
observed=observed_emissions,
)
# Specify known states
known_indices = np.where(~np.isnan(known_states))[0]
is_correct_state = pt.eq(hidden_states[known_indices], known_states[known_indices])
logp = pt.sum(pt.switch(is_correct_state, -0.02, -3.91))
pm.Potential("hidden_states_observed", logp)
To convert it to a hierarchical model, I’ve added the extra dimension, participant
to the distributions and data, but I’m confused about how to adapt my scan function to cope with it. I’ve looked at some similar questions on the discourse, but none quite answer this question. Below is the code for the hierarchical model which (hopefully) only needs the scan function adapted to work.
model_data = { #Each value is a list of arrays
"participants": participants,
"explicit_transition_indicies": explicit_transition_index_arrays,
"initial_states": initial_states_array,
"length": length_arrays,
"known_states": known_state_arrays,
"cue_concordant_with_a": cue_concordant_with_a_arrays,
}
coords = {
"participant": participants,
"time": np.arange(n_timesteps),
"state": ["xt=a", "xt=b"],
"next_state": ["xt+1=a", "xt+1=b"],
"emission": ["emission_concordant_with_a", "emission_concordant_with_b"],
}
with pm.Model(coords=coords) as model:
def step(P, x_tm1):
x_t = pm.Categorical.dist(p=P[x_tm1])
return x_t, collect_default_updates(x_t)
def markov_chain(x0, Ps, shape=None):
states, _ = pytensor.scan(step, outputs_info=[x0], sequences=[Ps])
return states
# Set transition matricies with informative priors
transition_matrix_baseline = pm.Dirichlet(
"transition_matrix_baseline",
a=[[5, 1.2], [1.2, 5]],
dims=["participant", "state", "next_state"],
)
transition_matrix_explicit = pm.Dirichlet(
"transition_matrix_explicit",
a=[[1.2, 5], [5, 1.2]],
dims=["participant", "state", "next_state"],
)
transition_matricies = pt.stack(
[transition_matrix_baseline] * n_timesteps
).dimshuffle(1, 0, 2, 3)
for participant_index, explicit_transition_indicies in enumerate(
model_data["explicit_transition_indicies"]
):
transition_matricies = transition_matricies[
participant_index, explicit_transition_indicies, :, :
].set(transition_matrix_explicit[participant_index, :, :])
transition_matricies = pm.Deterministic(
"transition_matricies",
transition_matricies,
dims=["participant", "time", "state", "next_state"],
)
p_x0s = pm.Data(
"initial_state",
[0.1 if x == "a" else 0.9 for x in model_data["initial_state"]],
dims="participant",
)
x0 = pm.Bernoulli("x0", p=p_x0s, dims="participant")
hidden_states = pm.CustomDist(
"hidden_states",
x0,
transition_matricies,
dist=markov_chain,
dims=["participant", "time"],
)
observed_emissions = pm.Data(
"obsereved_emissions",
model_data["cue_concordant_with_a"],
dims=["participant", "time"],
)
emission_ps = pm.Dirichlet(
"p_emission_concordant_with_true_state",
a=[[3, 1], [1, 3]],
dims=["participant", "state", "emission"],
)
emission_matricies = pm.math.stack([emission_ps] * n_timesteps)
emissions = pm.Categorical(
"emissions",
p=emission_matricies[0, hidden_states],
dims=["participant", "time"],
observed=observed_emissions,
)
# Specify known states
known_states = np.stack(model_data["known_states"])
known_indices = np.where(~np.isnan(known_states))[0]
is_correct_state = pt.eq(hidden_states[known_indices], known_states[known_indices])
logp = pt.sum(pt.switch(is_correct_state, -0.02, -3.91))
pm.Potential("hidden_states_observed", logp)
Are you able to point me in the right direction to get the hierarchical scan working?