Complex model samples very slowly - wondering about possible improvements

Thank you for your answer! I have indeed written something like what you suggested:

import numpy as np
import pymc3 as pm
import theano as tt  # used below as tt.shared, tt.printing, tt.config
import theano.tensor as T


def theano_RSA(
        possible_signals_array=T.ltensor3("possible_signals_array"),
        real_signals_indices=T.lvector("real_signals_indices"),
        alphas=T.dvector("alphas"),
        choice_alphas=T.dvector("choice_alphas"),
        cost_factors=T.dvector("cost_factors"),
        objective_costs_possible=T.dvector("objective_costs_possible"),
        at_most_as_costly=T.lmatrix("at_most_as_costly"),
        types=T.lvector("types"),
        distances=T.dmatrix("distances"),
        return_symbolic=True):

    real_signals_array = possible_signals_array[:, real_signals_indices]  # (not used below)
    objective_costs_real = objective_costs_possible[real_signals_indices]

    max_pic_size = possible_signals_array.shape[0] - 1  # (not used below)
    considered_signals = at_most_as_costly & types.dimshuffle("x", 0)

    # literal listener l0: uniform over the states in which each signal is true,
    # then a soft-max choice rule over the expected distance to the true state
    language_l = possible_signals_array / possible_signals_array.sum(axis=-1, keepdims=True)
    expected_dist_l0 = T.tensordot(language_l, distances, axes=[[2],[0]])

    unnorm_l0 = T.exp(choice_alphas[:,np.newaxis,np.newaxis,np.newaxis]*-expected_dist_l0)
    shape = unnorm_l0.shape
    _, picsize_index, _, state_index = T.mgrid[0:shape[0], 0:shape[1], 0:shape[2], 0:shape[3]]
    unnorm_l0 = T.switch(state_index > picsize_index, 0, unnorm_l0)  # mask states > picsize
    l0 = unnorm_l0 / T.sum(unnorm_l0, axis=-1, keepdims=True)

    l0_extended = l0[:,:,np.newaxis,:,:]  # add an axis for the real signals

    # pragmatic speaker s1: soft-max of listener (log-)utility minus signal cost
    costs_possible = T.outer(cost_factors, objective_costs_possible)
    utility_l0 = T.log(l0_extended)
    unnorm_s1 = T.exp(
        alphas[:,np.newaxis, np.newaxis, np.newaxis, np.newaxis] *
        (utility_l0 - costs_possible[:,np.newaxis,np.newaxis,:,np.newaxis])
    )
    # keep only the alternative signals the speaker is assumed to consider
    unnorm_s1 = unnorm_s1 * considered_signals[np.newaxis,np.newaxis,:,:,np.newaxis]
    s1 = unnorm_s1 / unnorm_s1.sum(axis=-2, keepdims=True)
    s1 = T.switch(T.isnan(s1), 0., s1)  # 0/0 rows come out as NaN; reset them to 0

    # pragmatic listener l2, derived from the speaker s1
    l2 = s1 / s1.sum(axis=-1, keepdims=True)
    expected_dist_l2 = T.tensordot(l2, distances, axes=[[4],[0]])

    unnorm_l2 = T.exp(choice_alphas[:,np.newaxis,np.newaxis,np.newaxis,np.newaxis]*-expected_dist_l2)
    shape = unnorm_l2.shape
    _, picsize_index, _, _, state_index = T.mgrid[
        0:shape[0], 
        0:shape[1], 
        0:shape[2], 
        0:shape[3], 
        0:shape[4]
    ]
    unnorm_l2 = T.switch(state_index > picsize_index, 0, unnorm_l2)  # mask states > picsize, as for l0
    l2 = unnorm_l2 / T.sum(unnorm_l2, axis=-1, keepdims=True)

    # pragmatic speaker s3, restricted to the signals that are real options
    l2_language = l2[:,:,T.arange(real_signals_indices.shape[0]), real_signals_indices,:].squeeze()
    costs_real = T.outer(cost_factors, objective_costs_real)
    utility_l2 = T.log(l2_language)

    unnorm_s3 = T.exp(
        alphas[:,np.newaxis,np.newaxis, np.newaxis]*
        (utility_l2 - costs_real[:,np.newaxis,:,np.newaxis])
    )

    s3 = unnorm_s3 / unnorm_s3.sum(axis=-2, keepdims=True)
    s3 = T.switch(T.isnan(s3), 0, s3)  # reset 0/0 rows to 0, as for s1

    if return_symbolic:
        return s3

    # compile the graph into a callable (used by produce_artificial_data below)
    return tt.function(
        [possible_signals_array, real_signals_indices, alphas, choice_alphas,
         cost_factors, objective_costs_possible, at_most_as_costly, types,
         distances],
        s3
    )
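(The `return_symbolic` flag lets the same graph be used twice: compiled into a plain function for simulating the data, and symbolically inside the PyMC3 model further down. A quick smoke test of the compiled version, using the arrays built by `produce_artificial_data` below, would look like this:)

# sketch: compile the graph once, then call it on concrete numpy arrays
rsa = theano_RSA(return_symbolic=False)
s3_values = rsa(possible_signals_array, real_signals_indices,
                alphas, choice_alphas, cost_factors, costs_possible,
                at_most_as_costly, types, distances)
# s3_values has shape (participant, picsize, real signal, state)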

(And the associated functions for producing the data and running the model:)

def produce_artificial_data(num_states, num_participants):

    # (possible signals, states)
    states = np.arange(num_states)

    idx_picsize, idx_state = np.mgrid[0:num_states, 0:num_states]
    # boolean mask for the valid cells (state <= picture size)
    picsize_geq_states = (idx_picsize >= idx_state)[None, :, :]
    # proportions_array has shape (# states, # states):
    # first dimension indicates the size of the picture,
    # second dimension indicates the # objects in the target set
    with np.errstate(divide="ignore", invalid="ignore"):
        # row 0 divides by zero; the resulting inf/nan cells are removed
        # by the picsize_geq_states mask below
        proportions_array = idx_state / idx_picsize

    #### for each signal, specify name, array, cost, type
    # make sure that the name is the same as the name recorded in the experimental data
    #### signals in experiment
    # for now, wrt types: 1=upward monotonic, 2=point, 3=downward monotonic
    most = ("Most", proportions_array > 0.5, 1, 1)
    mth = ("More than half", proportions_array > 0.5, 3, 1)
    every = ("All", proportions_array == 1., 1, 1)
    half = ("Half",  np.abs(proportions_array-0.5) == np.min( np.abs(proportions_array-0.5), axis=1, keepdims=True), 1, 1)  
    many = ("Many", proportions_array > 0.4, 1, 1)
    none = ("None", proportions_array == 0., 1, 3)
    lth = ("Less than half", proportions_array < 0.5, 3, 3)
    few = ("Few", proportions_array < 0.2, 1, 3)
    some = ("Some", proportions_array > 0., 1, 1)

    # signals that I am assuming the speaker thinks the listener considers,
    # but which are not real options in the experiment
    more_than_2_3rds = ("more_than_2_3rds", proportions_array > 0.66, 3, 1)
    more_than_3_4ths = ("more_than_3_4ths", proportions_array > 0.75, 3, 1)
    less_than_1_3rds = ("less_than_1_3rds", proportions_array < 0.33, 3, 3)
    less_than_1_4ths = ("less_than_1_4ths", proportions_array < 0.25, 3, 3)

    signals = [
        most, mth, every, many, 
        half,
        none, lth, few, some, 
        more_than_2_3rds, more_than_3_4ths, 
        less_than_1_3rds, less_than_1_4ths
    ]
    real_signals_indices = np.arange(9)

    names = [signal[0] for signal in signals]
    # signals_language_array_unrestricted doesn't exclude the states > picsize yet
    signals_language_array_unrestricted = np.array([s[1] for s in signals], dtype="int")
    possible_signals_array = np.transpose(signals_language_array_unrestricted & picsize_geq_states, axes=(1,0,2))
    costs_possible = np.array([s[2] for s in signals])
    types = np.array([s[3] for s in signals])

    # (real signals)
    costs_real = costs_possible[real_signals_indices]

    # (real signals, possible signals)
    at_most_as_costly = ((costs_real.reshape(-1,1) - costs_possible) >= 0).astype(int)

    # (states, states)
    a = np.tile(np.arange(num_states), reps=(num_states,1))
    distances = np.abs(a - np.arange(num_states).reshape(-1,1))

    # (participants)
    # alphas = np.array([0.11, 0.22])
    alphas = np.repeat(np.exp([-0.22579135]), repeats=num_participants)

    # (participants)
    # choice_alphas = np.array([5.61, 1.])
    choice_alphas = np.repeat(np.exp([-0.22579135]), repeats=num_participants)

    # (participants)
    # cost_factors = np.array([1.01, 2.])
    cost_factors = np.repeat(np.exp([-0.22579135]), repeats=num_participants)

    num_trials = 10  # trials per participant

    # compiled (non-symbolic) RSA graph, used to simulate the choices
    s3 = theano_RSA(return_symbolic=False)(
        possible_signals_array, real_signals_indices, 
        alphas, choice_alphas, cost_factors, costs_possible, 
        at_most_as_costly, types, distances
    )

    ### produce participant information
    participants_indices = np.repeat(np.arange(num_participants), repeats=num_trials)

    # size of the picture for each trial
    picsizes_values = np.random.choice(np.arange(1,num_states), size=num_participants*num_trials)

    # the number of objects in the target set (<= picsize)
    states_values = np.apply_along_axis(
        lambda x: np.random.choice(np.arange(x[0])),  # x is a length-1 row
        arr=picsizes_values.reshape(-1, 1) + 1,
        axis=1
    )

    # probability of accepting 
    # shape: (# datapoints, # real signals)
    probs_accept = np.stack([
            s3[participant_index, picsize_value, :, state_value]
            for picsize_value, participant_index, state_value in zip(picsizes_values, participants_indices, states_values)
    ])

    # TODO: make this more elegant (see the vectorized sketch just below);
    # unfortunately np.random.multinomial doesn't accept an array of p vectors
    picked_signals_indices = np.apply_along_axis(
        lambda p: np.random.choice(len(p), p=p),
        arr=probs_accept,
        axis=1
    )
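    # A vectorized alternative to the apply_along_axis call above would be
    # inverse-CDF sampling (a sketch; np.random.multinomial doesn't
    # broadcast over a matrix of p vectors, but this achieves the same):
    #
    #   cdf = probs_accept.cumsum(axis=1)
    #   u = np.random.random_sample((probs_accept.shape[0], 1))
    #   picked_signals_indices = (u < cdf).argmax(axis=1)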

    return {
        "num_participants": num_participants, 
        "num_states": num_states, 
        "possible_signals_array": possible_signals_array, 
        "real_signals_indices": real_signals_indices, 
        "costs_possible": costs_possible, 
        "at_most_as_costly": at_most_as_costly, 
        "types": types, 
        "distances": distances,
        "picked_signals_indices": picked_signals_indices, 
        "picsizes_values": picsizes_values,
        "participants_indices": participants_indices,
        "states_values": states_values
    }

def define_model(num_participants, num_states, possible_signals_array, 
                 real_signals_indices, costs_possible, at_most_as_costly, types, distances,
                 picked_signals_indices, picsizes_values, participants_indices, states_values):

    with pm.Model() as model:

        ### hyperprior for population-level parameters over alphas
        pop_alpha_mu = pm.HalfNormal("pop_alpha_mu", sigma=1)
        pop_alpha_sigma = pm.HalfNormal("pop_alpha_sigma", sigma=1)
        alphas = pm.Gamma("alphas", mu=pop_alpha_mu, sigma=pop_alpha_sigma, shape=num_participants)
    #     alphas = pm.Uniform("alphas", lower=0.0, upper=1, shape=num_participants)

        ### hyperprior for population-level parameters over choice_alphas
        pop_choice_alpha_mu = pm.HalfNormal("pop_choice_alpha_mu", sigma=1)
        pop_choice_alpha_sigma = pm.HalfNormal("pop_choice_alpha_sigma", sigma=1)
        choice_alphas = pm.Gamma(
            "choice_alphas", 
            mu=pop_choice_alpha_mu, 
            sigma=pop_choice_alpha_sigma, 
            shape=num_participants
        )
    #     choice_alphas = pm.Uniform("choice_alphas", lower=0.0, upper=1, shape=num_participants)

        ### hyperprior for population-level parameters over cost_factors
        pop_cost_factors_mu = pm.HalfNormal("pop_cost_factors_mu", sigma=1)
        pop_cost_factors_sigma = pm.HalfNormal("pop_cost_factors_sigma", sigma=1)
        cost_factors = pm.Gamma(
            "cost_factors", 
            mu=pop_cost_factors_mu, 
            sigma=pop_cost_factors_sigma, 
            shape=num_participants
        )
    #     cost_factors = pm.Uniform("cost_factors", lower=0.0, upper=1, shape=num_participants)

    #     alphas_print = tt.printing.Print('alphas')(alphas)
        choice_alphas_print = tt.printing.Print('choice_alphas')(choice_alphas)
        cost_factors_print = tt.printing.Print('cost_factors')(cost_factors)

        ### TODO: hyperprior for population-level parameters over production error

        ### RSA model
        # s3: (participant, picsize, real_signal, state)
        s3 = theano_RSA(
            possible_signals_array=tt.shared(possible_signals_array, name="possible_signals_array"), 
            real_signals_indices=tt.shared(real_signals_indices, name="real_signals_indices"), 
            objective_costs_possible=tt.shared(costs_possible, name="objective_costs_possible"), 
            at_most_as_costly=tt.shared(at_most_as_costly, name="at_most_as_costly"), 
            types=tt.shared(types, name="types"),
            distances=tt.shared(distances, name="distances"),
            alphas=alphas, 
            choice_alphas=choice_alphas, 
            cost_factors=cost_factors,
            return_symbolic=True
        )

        probabilities_accept_pre_error = pm.Deterministic(
            "p_accept_pre_error", s3[participants_indices, picsizes_values, :, states_values]
        )

        ### apply the error to the production probability
        # TODO: add noise to categorical
        probability_accept = probabilities_accept_pre_error

#         prob_accept_print = tt.printing.Print('prob_accept')(probability_accept)

        ### observed
        obs = pm.Categorical(
            "picked_signals", 
            p=probability_accept, 
            shape=len(picked_signals_indices), 
            observed=picked_signals_indices
        )

#         for RV in model.basic_RVs:
#             print(RV.name, RV.logp(model.test_point))

    return model
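# Side note: the commented-out logp loop at the end of define_model has a
# built-in equivalent, which makes -inf/nan terms easy to spot before
# sampling:
#
#   model.check_test_point()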

tt.config.compute_test_value = 'ignore'

define_model_input = produce_artificial_data(5, 10)
model = define_model(**define_model_input)

with model:

    # step=pm.Metropolis()
    step=pm.NUTS()

    trace=pm.sample(
        step=step,
        cores=1
    )
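(For reference, this is the variant of the sampling call with the settings that are often suggested for the mass-matrix error shown below — just a sketch of the commonly suggested settings, not something I can confirm helps here. The default init is "jitter+adapt_diag", and the jitter can apparently start chains in a flat region of the posterior:)

with model:
    trace = pm.sample(
        init="adapt_diag",    # default is "jitter+adapt_diag"
        target_accept=0.9,    # default is 0.8; smaller, safer steps
        cores=1
    )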

The model samples on the simulated data until it is about 10% of the way through, and then raises the following error:

ValueError: Mass matrix contains zeros on the diagonal. 
The derivative of RV `cost_factors_log__`.ravel()[0] is zero.
The derivative of RV `cost_factors_log__`.ravel()[1] is zero.
The derivative of RV `cost_factors_log__`.ravel()[2] is zero.
The derivative of RV `cost_factors_log__`.ravel()[3] is zero.
The derivative of RV `cost_factors_log__`.ravel()[4] is zero.
The derivative of RV `cost_factors_log__`.ravel()[5] is zero.
The derivative of RV `cost_factors_log__`.ravel()[6] is zero.
The derivative of RV `cost_factors_log__`.ravel()[7] is zero.
The derivative of RV `cost_factors_log__`.ravel()[8] is zero.
The derivative of RV `cost_factors_log__`.ravel()[9] is zero.
The derivative of RV `pop_cost_factors_sigma_log__`.ravel()[0] is zero.
The derivative of RV `pop_cost_factors_mu_log__`.ravel()[0] is zero.
The derivative of RV `choice_alphas_log__`.ravel()[0] is zero.
The derivative of RV `choice_alphas_log__`.ravel()[1] is zero.
The derivative of RV `choice_alphas_log__`.ravel()[2] is zero.
The derivative of RV `choice_alphas_log__`.ravel()[3] is zero.
The derivative of RV `choice_alphas_log__`.ravel()[4] is zero.
The derivative of RV `choice_alphas_log__`.ravel()[5] is zero.
The derivative of RV `choice_alphas_log__`.ravel()[6] is zero.
The derivative of RV `choice_alphas_log__`.ravel()[7] is zero.
The derivative of RV `choice_alphas_log__`.ravel()[8] is zero.
The derivative of RV `choice_alphas_log__`.ravel()[9] is zero.
The derivative of RV `pop_choice_alpha_sigma_log__`.ravel()[0] is zero.
The derivative of RV `pop_choice_alpha_mu_log__`.ravel()[0] is zero.
The derivative of RV `alphas_log__`.ravel()[0] is zero.
The derivative of RV `alphas_log__`.ravel()[1] is zero.
The derivative of RV `alphas_log__`.ravel()[2] is zero.
The derivative of RV `alphas_log__`.ravel()[3] is zero.
The derivative of RV `alphas_log__`.ravel()[4] is zero.
The derivative of RV `alphas_log__`.ravel()[5] is zero.
The derivative of RV `alphas_log__`.ravel()[6] is zero.
The derivative of RV `alphas_log__`.ravel()[7] is zero.
The derivative of RV `alphas_log__`.ravel()[8] is zero.
The derivative of RV `alphas_log__`.ravel()[9] is zero.
The derivative of RV `pop_alpha_sigma_log__`.ravel()[0] is zero.
The derivative of RV `pop_alpha_mu_log__`.ravel()[0] is zero.

I traced the problem back to the fact that the dlogp contains a bunch of NaN values:

dlogp:  [         nan          nan          nan          nan          nan
      nan          nan          nan          nan          nan
 -11.18093307  11.90769353          nan          nan          nan
      nan          nan          nan          nan          nan
      nan          nan -11.18093307  11.90769353          nan
      nan          nan          nan          nan          nan
      nan          nan          nan          nan -11.18093307 11.90769353]

I figured that the reason for all these NaNs in the dlogp was the switch statements, and indeed without the switches (i.e. the model in the initial post) the dlogp is defined and NUTS seems to work fine. Do you think that something else might be causing these NaN values in the dlogp? Thank you for your help!
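In the meantime, the mechanism I suspect, in case it helps: the gradient of T.switch(cond, a, b) with respect to each branch is the selection mask times the incoming gradient, and the chain rule then multiplies that by the branch's own local derivative. Wherever the *unselected* branch has an infinite local derivative (e.g. T.log at 0, or the 0/0 cells that the T.isnan switches catch), this gives 0 * inf = NaN, which poisons the whole dlogp. The standard workaround is a second, inner switch that feeds the fragile op a harmless filler value. A sketch of what that would look like for the utility_l0/unnorm_s1 step above (the mask and filler value are my own choices, not from the original model):

# "double switch" trick: keep log from ever seeing the zero cells, so its
# local gradient stays finite; the forward values are unchanged, because
# exp(alpha * (log(l0) - cost)) is 0 exactly where l0 is 0.
nonzero = l0_extended > 0
safe_l0 = T.switch(nonzero, l0_extended, 1.0)  # filler value, never selected
utility_l0 = T.log(safe_l0)                    # finite everywhere now
unnorm_s1 = T.switch(
    nonzero,
    T.exp(alphas[:, np.newaxis, np.newaxis, np.newaxis, np.newaxis]
          * (utility_l0 - costs_possible[:, np.newaxis, np.newaxis, :, np.newaxis])),
    0.0
)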