Complex model samples very slowly - wondering about possible improvements

Hi everyone,

Thank you for the amazing work you've been doing! I am writing a model to fit data from a behavioural experiment. Sampling is extremely slow and the model is fairly involved, so I was wondering whether there is a way to make it more efficient.

import pymc3 as pm
import theano as tt
import theano.tensor as T
import numpy as np
import scipy.stats as stats


def theano_RSA(
        possible_signals_array=T.lmatrix("possible_signals_array"), 
        real_signals_indices=T.lvector("real_signals_indices"), 
        alphas=T.dvector("alphas"), 
        choice_alphas=T.dvector("choice_alphas"), 
        cost_factors= T.dvector("cost_factors"), 
        objective_costs_possible=T.dvector("objective_costs_possible"), 
        at_most_as_costly=T.lmatrix("at_most_as_costly"), 
        types=T.lvector("types"),
        distances=T.dmatrix("distances")):
    """
    Parameters
    ----------
    possible_signals_array: 
        Shape (# possible signals, # states)
    real_signals_indices: 
        the indices of the real signals in the possible_signals_array. 
        shape (# real signals)
    alphas: 
        shape (# participants)
    choice_alphas: 
        shape (# participants)
    cost_factors: 
        How much each participant weights the cost of the messages in production.
        shape (# participants)
    objective_costs_possible: 
        The "objective" cost of the signals
        shape (# signals)
    at_most_as_costly:
        for each real signal, whether each possible signal is at most as costly as the real signal
        shape (# real signals available to participant, # signals considered by listener). 
        i,j is 1 iff cost(i) >= cost(j)
    distances:
        shape (# states, # states)
    Returns
    -------
    theano variable
    """
    real_signals_array = possible_signals_array[real_signals_indices]
    objective_costs_real = objective_costs_possible[real_signals_indices]
    considered_signals = at_most_as_costly & types.dimshuffle("x", 0)

    language_l = possible_signals_array / possible_signals_array.sum(axis=-1, keepdims=True)
    expected_dist_l0 = T.dot(language_l, distances)

    unnorm_l0 = T.exp(choice_alphas[:,np.newaxis,np.newaxis]*-expected_dist_l0)
    l0 = unnorm_l0 / T.sum(unnorm_l0, axis=-1, keepdims=True)

    l0_extended = l0[:,np.newaxis,:,:]

    # speaker S1: log-utilities from the literal listener minus participant-weighted message costs
    costs_possible = T.outer(cost_factors, objective_costs_possible)
    utility_l0 = T.log(l0_extended)
    unnorm_s1 = T.exp(
        alphas[:,np.newaxis, np.newaxis, np.newaxis]*
        (utility_l0 - costs_possible[:,np.newaxis,:,np.newaxis])
    )
    unnorm_s1 = unnorm_s1 * considered_signals[np.newaxis,:,:,np.newaxis]
    s1 = unnorm_s1 / unnorm_s1.sum(axis=-2, keepdims=True)

    # listener L2: renormalise S1 over states, then softmax of the negative expected distance
    l2 = s1 / s1.sum(axis=-1, keepdims=True)
    expected_dist_l2 = T.dot(l2, distances)

    unnorm_l2 = T.exp(choice_alphas[:,np.newaxis,np.newaxis,np.newaxis]*-expected_dist_l2)
    l2 = unnorm_l2 / T.sum(unnorm_l2, axis=-1, keepdims=True)

    # speaker S3: restrict L2 to the signals actually available in the experiment
    l2_language = l2[:, T.arange(real_signals_indices.shape[0]), real_signals_indices, :]
    costs_real = T.outer(cost_factors, objective_costs_real)
    utility_l2 = T.log(l2_language)

    unnorm_s3 = T.exp(
        alphas[:,np.newaxis, np.newaxis]*
        (utility_l2 - costs_real[:,:,np.newaxis])
    )

    s3 = unnorm_s3 / unnorm_s3.sum(axis=-2, keepdims=True)

    return s3


def create_model(num_participants, num_states, possible_signals_array, real_signals_indices, costs_possible, 
                 at_most_as_costly, types, distances, picked_signals_indices, picsizes_values, participants_indices, states_values):

    states = np.arange(num_states)

    with pm.Model() as model:

        ### hyperprior for population-level parameters over alphas
        pop_alpha_mu = pm.HalfNormal("pop_alpha_mu", sigma=1)
        pop_alpha_sigma = pm.HalfNormal("pop_alpha_sigma", sigma=1)
        alphas = pm.Gamma("alphas", mu=pop_alpha_mu, sigma=pop_alpha_sigma, shape=num_participants)

        ### hyperprior for population-level parameters over choice_alphas
        pop_choice_alpha_mu = pm.HalfNormal("pop_choice_alpha_mu", sigma=1)
        pop_choice_alpha_sigma = pm.HalfNormal("pop_choice_alpha_sigma", sigma=1)
        choice_alphas = pm.Gamma("choice_alphas", mu=pop_choice_alpha_mu, sigma=pop_choice_alpha_sigma, shape=num_participants)

        ### hyperprior for population-level parameters over cost_factors
        pop_cost_factors_mu = pm.HalfNormal("pop_cost_factors_mu", sigma=1)
        pop_cost_factors_sigma = pm.HalfNormal("pop_cost_factors_sigma", sigma=1)
        cost_factors = pm.Gamma("cost_factors", mu=pop_cost_factors_mu, sigma=pop_cost_factors_sigma, shape=num_participants)

        s3_list = [[],]    # dummy entry for picture size 0; s3_list[k] holds the prediction for picture size k
        for state in states[1:]:

            s3 = theano_RSA(
                possible_signals_array=tt.shared(possible_signals_array[state], name="possible_signals_array"), 
                real_signals_indices=tt.shared(real_signals_indices, name="real_signals_indices"), 
                objective_costs_possible=tt.shared(costs_possible, name="objective_costs_possible"), 
                at_most_as_costly=tt.shared(at_most_as_costly, name="at_most_as_costly"), 
                types=tt.shared(types, name="types"),
                distances=tt.shared(distances[:state+1,:state+1], name="distances"),
                alphas=alphas, 
                choice_alphas=choice_alphas, 
                cost_factors=cost_factors
            )

            s3_list.append(s3)

        probability_accept = T.stack([
            s3_list[picsize_value][participant_index, :, state_value]
            for picsize_value, participant_index, state_value in zip(picsizes_values, participants_indices, states_values)
        ])

        # save the probability of acceptance
        pm.Deterministic("probability_accept", probability_accept)

        ### observed
        obs = pm.Categorical(
            "picked_signals", 
            p=probability_accept, 
            shape=len(picsizes_values), 
            observed=picked_signals_indices
        )

    return model


def run_model(model, cores, draws=1000, tune=1000):
    with model:
        step = pm.NUTS(target_accept=0.99)
        trace=pm.sample(
            cores=cores,
            step=step,
            draws=draws,
            tune=tune
        )
    return trace

As you can see, the model itself is not very complex at all; most of the complexity comes from the theano_RSA function.

One obvious way to make sampling more efficient would be to eliminate the loop in the model that builds s3_list. I initially had a fully vectorised function, but because the different elements have different sizes along one dimension, the RSA function produced a bunch of NaN values, which broke the Theano gradient calculation and therefore did not work with NUTS. I tried setting the NaNs to 0 with switch, but that broke the gradient too.

Part of the problem is that the tensors involved are quite large (e.g. shape (120, 13, 13, 20)), and there are several such tensors in the RSA function. I also tried variational inference, since as far as I know it is a faster option, but it would not work because of the observed categorical distribution at the end.

Please let me know if more detail on the various variables would be helpful!

PS: Here are some additional functions needed to run the model; they are not involved in the actual sampling, so they don't really matter for PyMC3:

def produce_artificial_data(num_participants, num_states, num_trials):
    ### produce language information
    states = np.arange(num_states)
    possible_signals_array, real_signals_indices, costs_possible, at_most_as_costly, types = produce_experiment_language(states)
    distances = create_distance(num_states)

    ### produce participant information
    participants_indices = np.repeat(np.arange(num_participants), repeats=num_trials)

    # size of the picture for each trial
    picsizes_values = np.random.choice(np.arange(1,num_states), size=num_participants*num_trials)

    # the number of objects in the target set (<= picsize)
    states_values = np.apply_along_axis(
        lambda x: np.random.choice(np.arange(x)),
        arr=picsizes_values.reshape(-1,1)+1,
        axis=1
    )

    alphas = stats.gamma.rvs(a=1.3, scale=1/1.3, size=num_participants)
    choice_alphas = stats.gamma.rvs(a=3., scale=1/2.4, size=num_participants)
    cost_factors = stats.gamma.rvs(a=3., scale=1/2.4, size=num_participants)

    # WARNING: this might create a HUGE array, so be careful with the # of participants parameter
    s3_list = [[],]
    for state in states[1:]:

        s3=theano_RSA(return_symbolic=False)(**{
            "possible_signals_array": possible_signals_array[state], 
            "real_signals_indices": real_signals_indices, 
            "alphas": alphas, 
            "choice_alphas": choice_alphas, 
            "cost_factors": cost_factors, 
            "objective_costs_possible": costs_possible, 
            "at_most_as_costly": at_most_as_costly, 
            "types": types,
            "distances": distances[:state+1, :state+1]
        })

        s3_list.append(s3)

    # probability of accepting 
    # shape: (# datapoints, # real signals)
    probs_accept = np.stack([
            s3_list[picsize_value][participant_index, :, state_value]
            for picsize_value, participant_index, state_value in zip(picsizes_values, participants_indices, states_values)
    ])

    # TODO: make this more elegant. Not so nice now.
    # unfortunately np.random.multinomial doesn't accept an array of p vectors
    picked_signals_indices = np.apply_along_axis(
        lambda p: np.random.choice(len(p), p=p),
        arr=probs_accept,
        axis=1
    )

    return {
        "num_participants": num_participants, 
        "num_states": num_states, 
        "possible_signals_array": possible_signals_array, 
        "real_signals_indices": real_signals_indices, 
        "costs_possible": costs_possible, 
        "at_most_as_costly": at_most_as_costly, 
        "types": types, 
        "distances": distances,
        "picked_signals_indices": picked_signals_indices, 
        "picsizes_values": picsizes_values,
        "participants_indices": participants_indices,
        "states_values": states_values
    }, {"alphas": alphas, "choice_alphas": choice_alphas, "cost_factors": cost_factors}


def produce_experiment_language(states):
    """
    Produce various information concerning the language used by participants in experiment
    The important thing about this function is that it depends on the data only through `states`
    (the array of possible states), so it is conceptually convenient to keep it separate from the data.
    """

    # proportions_arrays: one array per picture size, giving the possible proportions of
    # objects in the target set (the array for picture size p has length p + 1)
    proportions_arrays = [np.linspace(0, 1, num_objects) for num_objects in states+1]

    #### for each signal, specify name, array, cost, type
    # make sure that the name is the same as the name recorded in the experimental data
    #### signals in experiment
    # for now, wrt types: 1=upward monotonic, 2=point, 3=downward monotonic
    most = ("Most", [p > 0.5 for p in proportions_arrays], 1, 1)
    mth = ("More than half", [p > 0.5 for p in proportions_arrays], 3, 1)
    every = ("All", [p == 1. for p in proportions_arrays], 1, 1)
    # approximate "half" for the cases where there is no perfect mid?
    half = ("Half",  [np.abs(p-0.5) == np.min(np.abs(p-0.5)) if p.size!=0 else np.array([]) for p in proportions_arrays], 1, 1)  
    many = ("Many", [p > 0.4 for p in proportions_arrays], 1, 1)
    none = ("None", [p == 0. for p in proportions_arrays], 1, 3)
    lth = ("Less than half", [p < 0.5 for p in proportions_arrays], 3, 3)
    few = ("Few", [p < 0.2 for p in proportions_arrays], 1, 3)
    some = ("Some", [p > 0. for p in proportions_arrays], 1, 1)

    # signals that I am assuming the speaker thinks of the listener thinking about,
    # but are not real options in the experiment
    more_than_2_3rds = ("more_than_2_3rds", [p > 0.66 for p in proportions_arrays], 3, 1)
    more_than_3_4ths = ("more_than_3_4ths", [p > 0.75 for p in proportions_arrays], 3, 1)
    less_than_1_3rds = ("less_than_1_3rds", [p < 0.33 for p in proportions_arrays], 3, 3)
    less_than_1_4ths = ("less_than_1_4ths", [p < 0.25 for p in proportions_arrays], 3, 3)

    signals = [
        most, mth, every, half, many, none, lth, few, some, 
        more_than_2_3rds, more_than_3_4ths, 
        less_than_1_3rds, less_than_1_4ths
    ]
    real_signals_indices = np.arange(9).astype(int)

    names = [signal[0] for signal in signals]
    # possible_signals_array: for each picsize, an array with shape (# signals, picsize + 1)
    possible_signals_array = [
        np.array(a).astype(int) 
        for a in list(zip(*[s[1] for s in signals]))
    ]
    costs_possible = np.array([s[2] for s in signals])
    types = np.array([s[3] for s in signals])

    # (real signals)
    costs_real = costs_possible[real_signals_indices]

    # (real signals, possible signals)
    at_most_as_costly = ((costs_real.reshape(-1,1) - costs_possible) >= 0).astype(int)

    return possible_signals_array, real_signals_indices, costs_possible, at_most_as_costly, types

def create_distance(num_states):
    # (states, states)
    a = np.tile(np.arange(num_states), reps=(num_states,1))
    distances = np.abs(a - np.arange(num_states).reshape(-1,1))
    return distances

num_participants = 10    # will be 120
num_states = 5    # will be 20
num_trials = 10    # will be 300

model_input, real_parameters =  produce_artificial_data(
        num_participants,
        num_states,
        num_trials
)

model = create_model(**model_input)
trace = run_model(model, 1)

I think you have the right intuition with your attempt to create a vectorized function. I am not sure why switch would break the gradients. Did you write something like x = T.switch(T.isnan(x), 0., x)? Do you think you could flatten along the ragged dimension? I know it can be a headache to keep track of the start/stop indices of the different pieces of that vector, but it could help speed the computations up.
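
For what it's worth, here is a rough sketch of the bookkeeping I have in mind, with made-up shapes (13 signals, picture sizes 1 to 4); the offsets let you slice the flattened axis back into the per-picsize pieces:

import numpy as np

# toy ragged data: one (n_signals, picsize + 1) array per picture size
ragged = [np.random.rand(13, picsize + 1) for picsize in range(1, 5)]

# flatten the ragged (state) dimension into a single long axis
lengths = np.array([a.shape[1] for a in ragged])
stops = np.cumsum(lengths)
starts = stops - lengths
flat = np.concatenate(ragged, axis=1)   # shape (13, lengths.sum())

# recover the block belonging to picture size k (0-indexed into `ragged`)
k = 2
block = flat[:, starts[k]:stops[k]]
assert np.allclose(block, ragged[k])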

Also, a more basic thing to check might be where divergences occur. I notice you have quite a high target accept rate. If you relax it and look at where the divergences happen, you may uncover the troublesome geometry that forces NUTS to take lots of leapfrog steps to get accepted samples. The scatter plots in this notebook may be a useful example.
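
Something along these lines should show them; the variable names are taken from your model and the plotting choices are arbitrary:

import matplotlib.pyplot as plt

with model:
    # relaxed from target_accept=0.99 so that divergences are easier to provoke and inspect
    trace = pm.sample(draws=1000, tune=1000, cores=1, target_accept=0.9)

divergent = trace.get_sampler_stats("diverging")
print("number of divergent transitions:", divergent.sum())

# scatter one participant's parameters, colouring the divergent draws
plt.scatter(trace["alphas"][:, 0], trace["choice_alphas"][:, 0],
            c=divergent, cmap="coolwarm", s=8)
plt.xlabel("alphas[0]")
plt.ylabel("choice_alphas[0]")
plt.show()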

Thank you for your answer! I have indeed written something like what you suggested:

def theano_RSA(
        possible_signals_array=T.ltensor3("possible_signals_array"), 
        real_signals_indices=T.lvector("real_signals_indices"), 
        alphas=T.dvector("alphas"), 
        choice_alphas=T.dvector("choice_alphas"), 
        cost_factors= T.dvector("cost_factors"), 
        objective_costs_possible=T.dvector("objective_costs_possible"), 
        at_most_as_costly=T.lmatrix("at_most_as_costly"), 
        types=T.lvector("types"),
        distances=T.dmatrix("distances")):

    real_signals_array = possible_signals_array[:,real_signals_indices]
    objective_costs_real = objective_costs_possible[real_signals_indices]

    max_pic_size = possible_signals_array.shape[0] - 1
    considered_signals = at_most_as_costly & types.dimshuffle("x", 0)

    language_l = possible_signals_array / possible_signals_array.sum(axis=-1, keepdims=True)
    expected_dist_l0 = T.tensordot(language_l, distances, axes=[[2],[0]])

    unnorm_l0 = T.exp(choice_alphas[:,np.newaxis,np.newaxis,np.newaxis]*-expected_dist_l0)
    shape = unnorm_l0.shape
    _, picsize_index, _, state_index = T.mgrid[0:shape[0], 0:shape[1], 0:shape[2], 0:shape[3]]
    unnorm_l0 = T.switch(state_index > picsize_index, 0, unnorm_l0)
    l0 = unnorm_l0 / T.sum(unnorm_l0, axis=-1, keepdims=True)

    l0_extended = l0[:,:,np.newaxis,:,:]

    costs_possible = T.outer(cost_factors, objective_costs_possible)
    utility_l0 = T.log(l0_extended)
    unnorm_s1 = T.exp(
        alphas[:,np.newaxis, np.newaxis, np.newaxis, np.newaxis] *
        (utility_l0 - costs_possible[:,np.newaxis,np.newaxis,:,np.newaxis])
    )
    unnorm_s1 = unnorm_s1 * considered_signals[np.newaxis,np.newaxis,:,:,np.newaxis]
    s1 = unnorm_s1 / unnorm_s1.sum(axis=-2, keepdims=True)
    s1 = T.switch(T.isnan(s1), 0., s1)

    l2 = s1 / s1.sum(axis=-1, keepdims=True)
    expected_dist_l2 = T.tensordot(l2, distances, axes=[[4],[0]])

    unnorm_l2 = T.exp(choice_alphas[:,np.newaxis,np.newaxis,np.newaxis,np.newaxis]*-expected_dist_l2)
    shape = unnorm_l2.shape
    _, picsize_index, _, _, state_index = T.mgrid[
        0:shape[0], 
        0:shape[1], 
        0:shape[2], 
        0:shape[3], 
        0:shape[4]
    ]
    unnorm_l2 = T.switch(state_index > picsize_index, 0, unnorm_l2)
    l2 = unnorm_l2 / T.sum(unnorm_l2, axis=-1, keepdims=True)

    l2_language = l2[:,:,T.arange(real_signals_indices.shape[0]), real_signals_indices,:].squeeze()
    costs_real = T.outer(cost_factors, objective_costs_real)
    utility_l2 = T.log(l2_language)

    unnorm_s3 = T.exp(
        alphas[:,np.newaxis,np.newaxis, np.newaxis]*
        (utility_l2 - costs_real[:,np.newaxis,:,np.newaxis])
    )

    s3 = unnorm_s3 / unnorm_s3.sum(axis=-2, keepdims=True)
    s3 = T.switch(T.isnan(s3), 0, s3)

    return s3

(And the associated functions for producing the data and running the model:)

def produce_artificial_data(num_states, num_participants):
    num_states = 20    # note: this overrides the function argument

    # (possible signals, states)
    states = np.arange(num_states)

    idx_picsize, idx_state = np.mgrid[0:num_states, 0:num_states]
    picsize_geq_states = (idx_picsize >= idx_state)[None,:,:]  # boolean mask: True where the state count fits the picture size
    # proportions_array has shape (# states, # states)
    # first dimension indicates the size of the picture, 
    # second dimension indicates the # objects in target set
    proportions_array = idx_state / idx_picsize

    #### for each signal, specify name, array, cost, type
    # make sure that the name is the same as the name recorded in the experimental data
    #### signals in experiment
    # for now, wrt types: 1=upward monotonic, 2=point, 3=downward monotonic
    most = ("Most", proportions_array > 0.5, 1, 1)
    mth = ("More than half", proportions_array > 0.5, 3, 1)
    every = ("All", proportions_array == 1., 1, 1)
    half = ("Half",  np.abs(proportions_array-0.5) == np.min( np.abs(proportions_array-0.5), axis=1, keepdims=True), 1, 1)  
    many = ("Many", proportions_array > 0.4, 1, 1)
    none = ("None", proportions_array == 0., 1, 3)
    lth = ("Less than half", proportions_array < 0.5, 3, 3)
    few = ("Few", proportions_array < 0.2, 1, 3)
    some = ("Some", proportions_array > 0., 1, 1)

    # signals that I am assuming the speaker thinks of the listener thinking about,
    # but are not real options in the experiment
    more_than_2_3rds = ("more_than_2_3rds", proportions_array > 0.66, 3, 1)
    more_than_3_4ths = ("more_than_3_4ths", proportions_array > 0.75, 3, 1)
    less_than_1_3rds = ("less_than_1_3rds", proportions_array < 0.33, 3, 3)
    less_than_1_4ths = ("less_than_1_4ths", proportions_array < 0.25, 3, 3)

    signals = [
        most, mth, every, many, 
        half,
        none, lth, few, some, 
        more_than_2_3rds, more_than_3_4ths, 
        less_than_1_3rds, less_than_1_4ths
    ]
    real_signals_indices = np.arange(9)

    names = [signal[0] for signal in signals]
    # signals_language_array_unrestricted doesn't exclude the states > picsize yet
    signals_language_array_unrestricted = np.array([s[1] for s in signals], dtype="int")
    possible_signals_array = np.transpose(signals_language_array_unrestricted & picsize_geq_states, axes=(1,0,2))
    costs_possible = np.array([s[2] for s in signals])
    types = np.array([s[3] for s in signals])

    # (real signals)
    costs_real = costs_possible[real_signals_indices]

    # (real signals, possible signals)
    at_most_as_costly = ((costs_real.reshape(-1,1) - costs_possible) >= 0).astype(int)

    # (states, states)
    a = np.tile(np.arange(num_states), reps=(num_states,1))
    distances = np.abs(a - np.arange(num_states).reshape(-1,1))

    num_participants = 10    # note: this overrides the function argument

    # (participants)
    # alphas = np.array([0.11, 0.22])
    alphas = np.repeat(np.exp([-0.22579135]), repeats=num_participants)

    # (participants)
    # choice_alphas = np.array([5.61, 1.])
    choice_alphas = np.repeat(np.exp([-0.22579135]), repeats=num_participants)

    # (participants)
    # cost_factors = np.array([1.01, 2.])
    cost_factors = np.repeat(np.exp([-0.22579135]), repeats=num_participants)

    num_trials = 10

    s3 = theano_RSA(return_symbolic=False)(
        possible_signals_array, real_signals_indices, 
        alphas, choice_alphas, cost_factors, costs_possible, 
        at_most_as_costly, types, distances
    )

    ### produce participant information
    participants_indices = np.repeat(np.arange(num_participants), repeats=num_trials)

    # size of the picture for each trial
    picsizes_values = np.random.choice(np.arange(1,num_states), size=num_participants*num_trials)

    # the number of objects in the target set (<= picsize)
    states_values = np.apply_along_axis(
        lambda x: np.random.choice(np.arange(x)),
        arr=picsizes_values.reshape(-1,1)+1,
        axis=1
    )

    # probability of accepting 
    # shape: (# datapoints, # real signals)
    probs_accept = np.stack([
            s3[participant_index, picsize_value, :, state_value]
            for picsize_value, participant_index, state_value in zip(picsizes_values, participants_indices, states_values)
    ])

    # TODO: make this more elegant. Not so nice now.
    # unfortunately np.random.multinomial doesn't accept an array of p vectors
    picked_signals_indices = np.apply_along_axis(
        lambda p: np.random.choice(len(p), p=p),
        arr=probs_accept,
        axis=1
    )

    return {
        "num_participants": num_participants, 
        "num_states": num_states, 
        "possible_signals_array": possible_signals_array, 
        "real_signals_indices": real_signals_indices, 
        "costs_possible": costs_possible, 
        "at_most_as_costly": at_most_as_costly, 
        "types": types, 
        "distances": distances,
        "picked_signals_indices": picked_signals_indices, 
        "picsizes_values": picsizes_values,
        "participants_indices": participants_indices,
        "states_values": states_values
    }

def define_model(num_participants, num_states, possible_signals_array, 
                 real_signals_indices, costs_possible, at_most_as_costly, types, distances,
                 picked_signals_indices, picsizes_values, participants_indices, states_values):

    with pm.Model() as model:

        ### hyperprior for population-level parameters over alphas
        pop_alpha_mu = pm.HalfNormal("pop_alpha_mu", sigma=1)
        pop_alpha_sigma = pm.HalfNormal("pop_alpha_sigma", sigma=1)
        alphas = pm.Gamma("alphas", mu=pop_alpha_mu, sigma=pop_alpha_sigma, shape=num_participants)
    #     alphas = pm.Uniform("alphas", lower=0.0, upper=1, shape=num_participants)

        ### hyperprior for population-level parameters over choice_alphas
        pop_choice_alpha_mu = pm.HalfNormal("pop_choice_alpha_mu", sigma=1)
        pop_choice_alpha_sigma = pm.HalfNormal("pop_choice_alpha_sigma", sigma=1)
        choice_alphas = pm.Gamma(
            "choice_alphas", 
            mu=pop_choice_alpha_mu, 
            sigma=pop_choice_alpha_sigma, 
            shape=num_participants
        )
    #     choice_alphas = pm.Uniform("choice_alphas", lower=0.0, upper=1, shape=num_participants)

        ### hyperprior for population-level parameters over cost_factors
        pop_cost_factors_mu = pm.HalfNormal("pop_cost_factors_mu", sigma=1)
        pop_cost_factors_sigma = pm.HalfNormal("pop_cost_factors_sigma", sigma=1)
        cost_factors = pm.Gamma(
            "cost_factors", 
            mu=pop_cost_factors_mu, 
            sigma=pop_cost_factors_sigma, 
            shape=num_participants
        )
    #     cost_factors = pm.Uniform("cost_factors", lower=0.0, upper=1, shape=num_participants)

    #     alphas_print = tt.printing.Print('alphas')(alphas)
        choice_alphas_print = tt.printing.Print('choice_alphas')(choice_alphas)
        cost_factors_print = tt.printing.Print('cost_factors')(cost_factors)

        ### TODO: hyperprior for population-level parameters over production error

        ### RSA model
        # s3: (participant, picsize, real_signal, state)
        s3 = theano_RSA(
            possible_signals_array=tt.shared(possible_signals_array, name="possible_signals_array"), 
            real_signals_indices=tt.shared(real_signals_indices, name="real_signals_indices"), 
            objective_costs_possible=tt.shared(costs_possible, name="objective_costs_possible"), 
            at_most_as_costly=tt.shared(at_most_as_costly, name="at_most_as_costly"), 
            types=tt.shared(types, name="types"),
            distances=tt.shared(distances, name="distances"),
            alphas=alphas, 
            choice_alphas=choice_alphas, 
            cost_factors=cost_factors,
            return_symbolic=True
        )

        probabilities_accept_pre_error = pm.Deterministic(
            "p_accept_pre_error", s3[participants_indices, picsizes_values, :, states_values]
        )

        ### apply the error to the production probability
        # TODO: add noise to categorical
        probability_accept = probabilities_accept_pre_error

#         prob_accept_print = tt.printing.Print('prob_accept')(probability_accept)

        ### observed
        obs = pm.Categorical(
            "picked_signals", 
            p=probability_accept, 
            shape=len(picked_signals_indices), 
            observed=picked_signals_indices
        )

#         for RV in model.basic_RVs:
#             print(RV.name, RV.logp(model.test_point))

    return model

tt.config.compute_test_value = 'ignore'

define_model_input = produce_artificial_data(5, 10)
model = define_model(**define_model_input)

with model:

    # step=pm.Metropolis()
    step=pm.NUTS()

    trace=pm.sample(
        step=step,
        cores=1
    )

The model samples on the simulated data until it is about 10% of the way through, and then raises the following error:

ValueError: Mass matrix contains zeros on the diagonal. 
The derivative of RV `cost_factors_log__`.ravel()[0] is zero.
The derivative of RV `cost_factors_log__`.ravel()[1] is zero.
The derivative of RV `cost_factors_log__`.ravel()[2] is zero.
The derivative of RV `cost_factors_log__`.ravel()[3] is zero.
The derivative of RV `cost_factors_log__`.ravel()[4] is zero.
The derivative of RV `cost_factors_log__`.ravel()[5] is zero.
The derivative of RV `cost_factors_log__`.ravel()[6] is zero.
The derivative of RV `cost_factors_log__`.ravel()[7] is zero.
The derivative of RV `cost_factors_log__`.ravel()[8] is zero.
The derivative of RV `cost_factors_log__`.ravel()[9] is zero.
The derivative of RV `pop_cost_factors_sigma_log__`.ravel()[0] is zero.
The derivative of RV `pop_cost_factors_mu_log__`.ravel()[0] is zero.
The derivative of RV `choice_alphas_log__`.ravel()[0] is zero.
The derivative of RV `choice_alphas_log__`.ravel()[1] is zero.
The derivative of RV `choice_alphas_log__`.ravel()[2] is zero.
The derivative of RV `choice_alphas_log__`.ravel()[3] is zero.
The derivative of RV `choice_alphas_log__`.ravel()[4] is zero.
The derivative of RV `choice_alphas_log__`.ravel()[5] is zero.
The derivative of RV `choice_alphas_log__`.ravel()[6] is zero.
The derivative of RV `choice_alphas_log__`.ravel()[7] is zero.
The derivative of RV `choice_alphas_log__`.ravel()[8] is zero.
The derivative of RV `choice_alphas_log__`.ravel()[9] is zero.
The derivative of RV `pop_choice_alpha_sigma_log__`.ravel()[0] is zero.
The derivative of RV `pop_choice_alpha_mu_log__`.ravel()[0] is zero.
The derivative of RV `alphas_log__`.ravel()[0] is zero.
The derivative of RV `alphas_log__`.ravel()[1] is zero.
The derivative of RV `alphas_log__`.ravel()[2] is zero.
The derivative of RV `alphas_log__`.ravel()[3] is zero.
The derivative of RV `alphas_log__`.ravel()[4] is zero.
The derivative of RV `alphas_log__`.ravel()[5] is zero.
The derivative of RV `alphas_log__`.ravel()[6] is zero.
The derivative of RV `alphas_log__`.ravel()[7] is zero.
The derivative of RV `alphas_log__`.ravel()[8] is zero.
The derivative of RV `alphas_log__`.ravel()[9] is zero.
The derivative of RV `pop_alpha_sigma_log__`.ravel()[0] is zero.
The derivative of RV `pop_alpha_mu_log__`.ravel()[0] is zero.

I traced the problem back to the fact that dlogp contains a bunch of NaN values:

dlogp:  [         nan          nan          nan          nan          nan
      nan          nan          nan          nan          nan
 -11.18093307  11.90769353          nan          nan          nan
      nan          nan          nan          nan          nan
      nan          nan -11.18093307  11.90769353          nan
      nan          nan          nan          nan          nan
      nan          nan          nan          nan -11.18093307 11.90769353]

I figured that the reason for all these NaNs in the dlogp was the switch statement, and indeed without the switch (i.e. with the model in the initial post) the dlogp is defined and NUTS seems to work fine. Do you think something else might be causing these NaN values in the dlogp? Thank you for your help!

Unfortunately, I don't have much further advice, sorry! I do know that switch is not supposed to break the gradient, so there should be something there that can be addressed. Perhaps you could try narrowing it down to the Theano-only portions of your code and raising it on the Theano GitHub page.
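
One pattern I have seen for this kind of masking (no guarantee it applies to your model) is to sanitise the denominator before the division rather than overwriting the NaNs afterwards, so that a 0/0 never appears anywhere in the graph that the gradient has to pass through:

import theano.tensor as T

def safe_normalize(unnorm, axis):
    """Normalise `unnorm` along `axis`, returning 0 where the normaliser is 0.

    The denominator is patched *before* the division, so no NaN is ever
    created and the gradient stays finite for the masked entries.
    """
    total = unnorm.sum(axis=axis, keepdims=True)
    safe_total = T.switch(T.eq(total, 0.), 1., total)
    return T.switch(T.eq(total, 0.), 0., unnorm / safe_total)

The idea would be to use something like s1 = safe_normalize(unnorm_s1, axis=-2) in place of the divide-then-isnan-switch pair, but someone who knows the model better should check whether that preserves the intended semantics.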

Thank you, that was very useful! I posted the problem on Stack Overflow with the theano tag, but alas nobody answered. Still, I did manage to isolate the problem to Theano, namely to the fact that the Jacobian is NaN.

First, building and compiling the Jacobian:

def theano_RSA(
        possible_signals_array=T.ltensor3("possible_signals_array"), 
        real_signals_indices=T.lvector("real_signals_indices"), 
        alphas=T.dvector("alphas"), 
        choice_alphas=T.dvector("choice_alphas"), 
        cost_factors= T.dvector("cost_factors"), 
        objective_costs_possible=T.dvector("objective_costs_possible"), 
        at_most_as_costly=T.lmatrix("at_most_as_costly"), 
        types=T.lvector("types"),
        distances=T.dmatrix("distances")):

    real_signals_array = possible_signals_array[:,real_signals_indices]
    objective_costs_real = objective_costs_possible[real_signals_indices]

    max_pic_size = possible_signals_array.shape[0] - 1

    considered_signals = at_most_as_costly & types.dimshuffle("x", 0)

    language_l = possible_signals_array / possible_signals_array.sum(axis=-1, keepdims=True)
    expected_dist_l0 = T.tensordot(language_l, distances, axes=[[2],[0]])

    unnorm_l0 = T.exp(choice_alphas[:,np.newaxis,np.newaxis,np.newaxis]*-expected_dist_l0)
    shape = unnorm_l0.shape
    _, picsize_index, _, state_index = T.mgrid[
        0:shape[0], 
        0:shape[1], 
        0:shape[2], 
        0:shape[3]
    ]
    unnorm_l0 = T.switch(state_index > picsize_index, 0, unnorm_l0)
    l0 = unnorm_l0 / T.sum(unnorm_l0, axis=-1, keepdims=True)

    l0_extended = l0[:,:,np.newaxis,:,:]

    costs_possible = T.outer(cost_factors, objective_costs_possible)
    utility_l0 = T.log(l0_extended)
    unnorm_s1 = T.exp(
        alphas[:,np.newaxis, np.newaxis, np.newaxis, np.newaxis] *
        (utility_l0 - costs_possible[:,np.newaxis,np.newaxis,:,np.newaxis])
    )

    unnorm_s1 = unnorm_s1 * considered_signals[np.newaxis,np.newaxis,:,:,np.newaxis]
    s1 = unnorm_s1 / unnorm_s1.sum(axis=-2, keepdims=True)
    s1 = T.switch(T.isnan(s1), 0., s1)

    l2 = s1 / s1.sum(axis=-1, keepdims=True)
    expected_dist_l2 = T.tensordot(l2, distances, axes=[[4],[0]])
    
    unnorm_l2 = T.exp(choice_alphas[:,np.newaxis,np.newaxis,np.newaxis,np.newaxis]*-expected_dist_l2)
    shape = unnorm_l2.shape
    _, picsize_index, _, _, state_index = T.mgrid[
        0:shape[0], 
        0:shape[1], 
        0:shape[2], 
        0:shape[3], 
        0:shape[4]
    ]
    unnorm_l2 = T.switch(state_index > picsize_index, 0, unnorm_l2)
    l2 = unnorm_l2 / T.sum(unnorm_l2, axis=-1, keepdims=True)

    l2_language = l2[:,:,T.arange(real_signals_indices.shape[0]), real_signals_indices,:].squeeze()
    costs_real = T.outer(cost_factors, objective_costs_real)

    utility_l2 = T.log(l2_language)

    unnorm_s3 = T.exp(
        alphas[:,np.newaxis,np.newaxis, np.newaxis]*
        (utility_l2 - costs_real[:,np.newaxis,:,np.newaxis])
    )

    s3 = unnorm_s3 / unnorm_s3.sum(axis=-2, keepdims=True)
    s3 = T.switch(T.isnan(s3), 0, s3)
    
    return_value = T.jacobian(s3.flatten(), alphas)

    return tt.function([
                possible_signals_array, 
                real_signals_indices, 
                alphas, 
                choice_alphas, 
                cost_factors, 
                objective_costs_possible, 
                at_most_as_costly, 
                types,
                distances
            ], return_value, on_unused_input='warn'
        )

s3_function_grad = theano_RSA()

Then producing some simulated data:

num_participants = 1
num_states = 2
possible_signals_array = np.array([
    [[0, 0], 
     [0, 0]],
    [[1, 0], 
     [0, 1]]    
])
real_signals_indices = np.array([0])
costs_possible = np.array([1, 1])
at_most_as_costly = np.array([
    [1,1],
    [1,1]
])
types = np.array([1,1])
distances = np.array([
    [0,1],
    [1,0]
])
picked_signals_indices = np.array([0])
picsizes_values = np.array([1])
participants_indices = np.array([0])
states_values = np.array([1])
alphas = np.array([1.])
choice_alphas = np.array([1.])
cost_factors = np.array([0.01])

and finally, evaluating the Jacobian at that point returns all NaNs:

s3_function_grad(
    possible_signals_array = possible_signals_array, 
    real_signals_indices = real_signals_indices, 
    alphas = alphas, 
    choice_alphas = choice_alphas, 
    cost_factors = cost_factors, 
    objective_costs_possible = costs_possible, 
    at_most_as_costly = at_most_as_costly, 
    types = types,
    distances = distances
)

This should be the problem, assuming PyMC3 uses Theano's gradient machinery (as in the Jacobian above) to compute dlogp.
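
In case it is useful, I believe the core behaviour can be reproduced outside the model with a tiny example; my guess at the essential ingredients is a 0/0 followed by an isnan switch:

import numpy as np
import theano
import theano.tensor as T

a = T.dvector("a")
b = T.dvector("b")
x = a / b                          # 0/0 produces a NaN where a == b == 0
y = T.switch(T.isnan(x), 0., x)    # patching the output...
g = T.grad(y.sum(), a)             # ...still multiplies the zeroed branch by 1/b

f = theano.function([a, b], g)
print(f(np.array([1., 0.]), np.array([2., 0.])))   # expect [0.5, nan]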

I think @twiecki or @aseyboldt will be able to give some pointers here