Restarting sampling from stored multitrace: mixed sampling and tuning

I am fitting some data on a server, and the fitting takes more time (about 500 hours) than the maximum time limit for a job on the server (120 hours). So I thought the natural solution would be to run the model over multiple jobs.

First, I tried using DMTCP to checkpoint the state of the program, but the normal setup wouldn’t work (it raised the error “restore_brk: error: new/current break != saved break”). Now I am trying to simply save the trace at the end of each job, and then restart by passing it to the “trace” argument of pm.sample. However, two things are unclear to me about this:

  1. My model uses mixed sampling (NUTS and Gibbs). Is the sample function going to automatically account for that?
  2. Do I only need to tune in the first run, and can I then set the tune parameter of pm.sample to 0 for the successive runs? In pure Metropolis-Hastings the state of the sampler is just a point in parameter space, but I know that NUTS uses the tuning steps to adapt its hyperparameters, and I am unsure whether these are stored in the trace object.

Alternatively, is there a way to store the whole state of the running sampler, and then pick it up in a different job from where it stopped?
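For concreteness, this is roughly the scheme I have in mind across two jobs (a sketch; I am assuming that pm.save_trace/pm.load_trace and passing the old MultiTrace to the “trace” argument of pm.sample behave the way I understand them to):

import pymc3 as pm

# job 1: tune and sample, then save the trace to disk
with model:
    trace = pm.sample(draws=500, tune=3000)
    pm.save_trace(trace, directory="job1_trace", overwrite=True)

# job 2 (a separate process): rebuild the model identically,
# reload the trace, and continue sampling without tuning
with model:
    old_trace = pm.load_trace("job1_trace")
    trace = pm.sample(draws=500, tune=0, trace=old_trace)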

Mixed sampling will make this pretty difficult: we can pickle the state of the NUTS sampler after tuning, but I am not sure that all the post-tuning meta-state of the other samplers is stored. I guess I am more interested to know why the sampling takes so long (complex model or lots of data?), and whether there are ways to improve that.
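To illustrate the NUTS half, something like the following might work (an untested sketch; I keep everything on one chain and one core so that the step object tuned in this process is the one that gets pickled):

import pickle
import pymc3 as pm

# job 1: tune NUTS, then persist the tuned step object
with model:
    step = pm.NUTS()
    trace = pm.sample(draws=500, tune=3000, step=step, chains=1, cores=1)
with open("step.pkl", "wb") as f:
    pickle.dump(step, f)

# job 2: rebuild the model identically, unpickle the tuned
# step object, and skip tuning
with open("step.pkl", "rb") as f:
    step = pickle.load(f)
with model:
    trace = pm.sample(draws=500, tune=0, step=step, chains=1, cores=1)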

Hi!

Thank you for your answer. I think the problem is the complexity of the model. The data come from 60 participants, each producing around 150 data points, so really not that much data.

The model consists of a cognitive model, written in Theano, which returns, for a certain “state”, a tensor encoding the probability that a participant will produce each of a series of signals:

import numpy as np
import theano.tensor as T


def theano_RSA(
        # note: the default arguments are created once, at definition time,
        # so repeated calls build graphs over the same input variables
        possible_signals_array=T.lmatrix("possible_signals_array"),
        real_signals_indices=T.lvector("real_signals_indices"),
        alphas=T.dvector("alphas"),
        choice_alphas=T.dvector("choice_alphas"),
        cost_factors=T.dvector("cost_factors"),
        objective_costs_possible=T.dvector("objective_costs_possible"),
        at_most_as_costly=T.lmatrix("at_most_as_costly"),
        types=T.lvector("types"),
        distances=T.dmatrix("distances"),
        consider_costs=False,
        return_symbolic=True,
        return_gradient=False,
        return_variables=False):

    real_signals_array = possible_signals_array[real_signals_indices]

    # which alternative signals each signal is compared against
    considered_signals = types.dimshuffle(0, 'x') >= types.dimshuffle('x', 0)

    unique_alt_profiles, index_signals_profile = T.extra_ops.Unique(
        axis=0, return_inverse=True)(considered_signals)

    temp_array = T.eq(index_signals_profile.dimshuffle(0, 'x'),
                      index_signals_profile.dimshuffle('x', 0))
    cumsum = T.extra_ops.cumsum(temp_array, axis=1)
    signals_index_within_profile = T.nlinalg.diag(cumsum)

    # literal listener, reshaped to broadcast over participants and profiles
    l0_unreshaped = possible_signals_array / possible_signals_array.sum(
        axis=-1, keepdims=True)
    l0 = l0_unreshaped[np.newaxis, np.newaxis, :, :]

    unnorm_s1 = l0 ** alphas[:, np.newaxis, np.newaxis, np.newaxis]
    s1_unnorm = T.switch(
        unique_alt_profiles[np.newaxis, :, :, np.newaxis],
        unnorm_s1,
        1
    )
    s1 = s1_unnorm / s1_unnorm.sum(axis=-2, keepdims=True)

    l2 = s1 / s1.sum(axis=-1, keepdims=True)
    expected_dist_l2 = T.tensordot(l2, distances, axes=[[3], [0]])
    unnorm_l2 = T.exp(
        choice_alphas[:, np.newaxis, np.newaxis, np.newaxis]
        * -expected_dist_l2
    )

    l2 = unnorm_l2 / T.sum(unnorm_l2, axis=-1, keepdims=True)

    l2_language_possible = l2[:,
        index_signals_profile,
        signals_index_within_profile,
        :]

    l2_language = l2_language_possible[:, real_signals_indices, :]
    unnorm_s3 = l2_language ** alphas[:, np.newaxis, np.newaxis]
    s3 = unnorm_s3 / unnorm_s3.sum(axis=-2, keepdims=True)

    if return_variables:
        return s3, {
            "possible_signals_array": possible_signals_array,
            "real_signals_indices": real_signals_indices,
            "alphas": alphas,
            "choice_alphas": choice_alphas,
            "cost_factors": cost_factors,
            "objective_costs_possible": objective_costs_possible,
            "at_most_as_costly": at_most_as_costly,
            "types": types,
            "distances": distances
        }
    else:
        return s3
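As an aside, the symbolic output can be compiled into an ordinary function for testing the cognitive model in isolation, along these lines (a sketch; since consider_costs=False the cost-related inputs are not part of the graph and are left out of the input list):

import theano

# compile the symbolic cognitive model into a numeric function;
# return_variables=True gives back the input variables created as defaults
s3, variables = theano_RSA(return_variables=True)
rsa_fn = theano.function(
    inputs=[variables[name] for name in (
        "possible_signals_array", "real_signals_indices",
        "alphas", "choice_alphas", "types", "distances")],
    outputs=s3
)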

And the pymc3 part:

import pymc3 as pm
import theano
import theano.tensor as T


def create_model(num_participants, num_states, possible_signals_array,
            real_signals_indices, costs_possible, at_most_as_costly,
            types, distances, picked_signals_indices,
            picsizes_values, participants_indices, states_values,
            consider_costs, names):

    with pm.Model() as model:
        ### hyperprior for population-level parameters over alphas
        pop_alpha_mu = pm.HalfNormal("pop_alpha_mu", sigma=1)
        pop_alpha_sigma = pm.HalfNormal("pop_alpha_sigma", sigma=1)

        ### hyperprior for population-level parameters over choice_alphas
        pop_choice_alpha_mu = pm.HalfNormal("pop_choice_alpha_mu", sigma=1)
        pop_choice_alpha_sigma = pm.HalfNormal("pop_choice_alpha_sigma", sigma=1)

        alphas = pm.Gamma(
            "alphas",
            mu=pop_alpha_mu,
            sigma=pop_alpha_sigma,
            shape=num_participants
        )

        choice_alphas = pm.Gamma(
            "choice_alphas",
            mu=pop_choice_alpha_mu,
            sigma=pop_choice_alpha_sigma,
            shape=num_participants
        )

        min_picsize = min(picsizes_values)

        p_accept_pre_error = T.zeros(
            shape=(len(picsizes_values), len(real_signals_indices))
        )
        for state in range(min_picsize, num_states):

            arguments = {
                'possible_signals_array': theano.shared(
                    possible_signals_array[state], name="possible_signals_array"),
                'real_signals_indices': theano.shared(
                    real_signals_indices, name="real_signals_indices"),
                'types': theano.shared(
                    types, name="types"),
                'distances': theano.shared(
                    distances[:state+1, :state+1], name="distances"),
                'alphas': alphas,
                'choice_alphas': choice_alphas
            }

            s3 = theano_RSA(
                **arguments,
                consider_costs=consider_costs,
                return_symbolic=True
            )

            # rows of the dataset whose picture size equals the current state
            relevant_indices = (picsizes_values == state).nonzero()[0]

            subtensor = s3[
                participants_indices[relevant_indices], :,
                states_values[relevant_indices]
            ]

            p_accept_pre_error = T.set_subtensor(
                p_accept_pre_error[relevant_indices],
                subtensor
            )

        # TODO: add noise to categorical
        probability_accept = p_accept_pre_error
        pm.Deterministic("probability_accept", probability_accept)

        ### observed
        obs = pm.Categorical(
            "picked_signals",
            p=probability_accept,
            shape=len(picsizes_values),
            observed=picked_signals_indices
        )

    return model

The ugliest part is the loop where I progressively fill “p_accept_pre_error”. The problem is that adding a state dimension to theano_RSA and vectorizing across states breaks the gradient: NaNs appear in a switch statement, and although the switch excludes them from the forward pass, the gradient breaks anyway. As far as I understand, this is the known Theano behaviour where the gradient of the excluded branch is still evaluated, and the chain rule multiplies a NaN local derivative by a zero upstream gradient, and 0 * nan = nan.
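The usual workaround seems to be the “double switch” trick: feed the dangerous op a harmless dummy value so that neither branch ever produces a NaN, and then mask the output as before. A minimal sketch of the pattern (sqrt here is just a stand-in for whatever op produces the NaNs in my graph):

import theano
import theano.tensor as T

x = T.dvector("x")

# naive masking: the forward pass is fine, but the gradient of sqrt is
# also evaluated at the excluded points, where it is NaN, and the chain
# rule multiplies it by the zero upstream gradient: 0 * nan = nan
y_bad = T.switch(x > 0, T.sqrt(x), 0.0).sum()

# double switch: sanitize the input first, so neither branch ever
# produces a NaN, then mask the output as before
x_safe = T.switch(x > 0, x, 1.0)
y_good = T.switch(x > 0, T.sqrt(x_safe), 0.0).sum()

g_bad = theano.function([x], T.grad(y_bad, x))
g_good = theano.function([x], T.grad(y_good, x))
print(g_bad([4.0, -1.0]))   # [0.25, nan]
print(g_good([4.0, -1.0]))  # [0.25, 0.0]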

And finally the sampling:

with model:  # the model returned by create_model above
    step = pm.NUTS(target_accept=0.99)
    trace = pm.sample(
        step=step,
        draws=2000,
        tune=3000
    )

When I do prior checks on a subset of the data (10 participants), I also get convergence problems (the chains do not converge).

I would be very grateful for any idea on how to improve the model! Thank you again.

EDIT -------------------

I now realize that I had mistakenly commented out the “observed” line in the model, so I was effectively sampling from the prior. Once I add the observed data, the model runs much, much faster (around 20 hours for the full dataset). What could be the reason for that?

I am running the model with the observed data now, hopefully this time it will converge.