Using multiple cores with ODE Bayesian Inference

Hi!

I’m working the Hodgkin–Huxley model and I’m trying to run it with multiple cores. I’m using odeint from scipy and the following code:

@as_op(itypes=[pt.dvector, pt.dvector], otypes=[pt.dvector])
def pytensor_forward_model_matrix(theta, time):
    return solve_ode(theta, time, current=20)

def run_sampling():
    with pm.Model() as model:
        # Priors using Gamma distributions
        gNa = pm.Gamma("gNa", alpha=2, beta=1/60)  # Gamma with mean ~120, alpha=2, scale=60 -> scale = 1/beta
        gK = pm.Gamma("gK", alpha=3, beta=1/12)    # Gamma with mean ~36, alpha=3, scale=12
        gl = pm.Gamma("gl", alpha=2, beta=1/0.15)  # Gamma with mean ~0.3, alpha=2, scale=0.15
        Cm = pm.Gamma("Cm", alpha=2, beta=1/0.5)   # Gamma with mean ~1, alpha=2, scale=0.5

        # Priors using Normal distributions
        VNa = pm.Normal("VNa", mu=theta[3], sigma=5)  # Normal distribution for VNa
        Vk = pm.Normal("Vk", mu=theta[4], sigma=5)    # Normal distribution for Vk
        Vl = pm.Normal("Vl", mu=theta[5], sigma=5)    # Normal distribution for Vl

        # Likelihood
        sigma = pm.HalfNormal("sigma", sigma=10)  # Half-normal prior for sigma (standard deviation)

        # Finer time resolution for solving the ODE
        time_fine = np.arange(0, 40.0, 0.01)  # Finer time for ODE solution
        time_tensor = pt.as_tensor_variable(time_fine)

        # Solve the ODE at a finer resolution
        ode_solution = pytensor_forward_model_matrix(
            pm.math.stack([gNa, gK, gl, VNa, Vk, Vl, Cm]),
            time_tensor
        )

        # Downsample the ODE solution to match the observed data
        # Keep every 100th point from the fine resolution to match the 1ms step
        ode_solution_downsampled = ode_solution[::100]

        # Define the likelihood: observed data compared to the model's ODE solution
        pm.Normal("obs", mu=ode_solution_downsampled, sigma=sigma, observed=observed_data) #THIS IS THE LIKELIHOOD

        # Get the number of available CPU cores
        num_cores = mp.cpu_count()

        trace = pm.sample(tune=3000, draws=10000, cores=num_cores, chains=4)

    return trace, model

if __name__ == '__main__':
    
    sampler = "NUTS"

    # Set multiprocessing context
    mp.set_start_method('spawn', force=True)

    # Run the sampling
    trace, model = run_sampling()

but i’m getting the error:

RemoteTraceback Traceback (most recent call last) RemoteTraceback: “”" Traceback (most recent call last): File “c:\Users\camilo\AppData\Local\Programs\Python\Python311\Lib\site-packages\pymc\sampling\parallel.py”, line 117, in _unpickle_step_method self._step_method = cloudpickle.loads(self._step_method) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File “c:\Users\camilo\AppData\Local\Programs\Python\Python311\Lib\site-packages\pytensor\compile\ops.py”, line 221, in load_back obj = getattr(module, name) ^^^^^^^^^^^^^^^^^^^^^ AttributeError: module ‘main’ has no attribute ‘pytensor_forward_model_matrix’ During handling of the above exception, another exception occurred: Traceback (most recent call last): File “c:\Users\camilo\AppData\Local\Programs\Python\Python311\Lib\site-packages\pymc\sampling\parallel.py”, line 126, in run self._unpickle_step_method() File “c:\Users\camilo\AppData\Local\Programs\Python\Python311\Lib\site-packages\pymc\sampling\parallel.py”, line 119, in _unpickle_step_method raise ValueError(unpickle_error) ValueError: The model could not be unpickled. This is required for sampling with more than one core and multiprocessing context spawn or forkserver. “”" The above exception was the direct cause of the following exception: ValueError Traceback (most recent call last) ValueError: The model could not be unpickled. This is required for sampling with more than one core and multiprocessing context spawn or forkserver. The above exception was the direct cause of the following exception: ParallelSamplingError Traceback (most recent call last) Cell In[17], line 7 4 mp.set_start_method(‘spawn’, force=True) 6 # Run the sampling ----> 7 trace, model = run_sampling() Cell In[15], line 43 40 # Get the number of available CPU cores 41 num_cores = mp.cpu_count() —> 43 trace = pm.sample(tune=3000, draws=10000, cores=num_cores, chains=4) 45 return trace, model File c:\Users\camilo\AppData\Local\Programs\Python\Python311\Lib\site-packages\pymc\sampling\mcmc.py:842, in sample(draws, tune, chains, cores, random_seed, progressbar, progressbar_theme, step, var_names, nuts_sampler, initvals, init, jitter_max_retries, n_init, trace, discard_tuned_samples, compute_convergence_checks, keep_warning_stat, return_inferencedata, idata_kwargs, nuts_sampler_kwargs, callback, mp_ctx, blas_cores, model, **kwargs) [840](file:///C:/Users/camilo/AppData/Local/Programs/Python/Python311/Lib/site-packages/pymc/sampling/mcmc.py:840) _print_step_hierarchy(step) [841](file:///C:/Users/camilo/AppData/Local/Programs/Python/Python311/Lib/site-packages/pymc/sampling/mcmc.py:841) try: → [842](file:///C:/Users/camilo/AppData/Local/Programs/Python/Python311/Lib/site-packages/pymc/sampling/mcmc.py:842) _mp_sample(**sample_args, **parallel_args) [843](file:///C:/Users/camilo/AppData/Local/Programs/Python/Python311/Lib/site-packages/pymc/sampling/mcmc.py:843) except pickle.PickleError: [844](file:///C:/Users/camilo/AppData/Local/Programs/Python/Python311/Lib/site-packages/pymc/sampling/mcmc.py:844) _log.warning(“Could not pickle model, sampling singlethreaded.”) File c:\Users\camilo\AppData\Local\Programs\Python\Python311\Lib\site-packages\pymc\sampling\mcmc.py:1255, in _mp_sample(draws, tune, step, chains, cores, random_seed, start, progressbar, progressbar_theme, traces, model, callback, blas_cores, mp_ctx, **kwargs) [1253](file:///C:/Users/camilo/AppData/Local/Programs/Python/Python311/Lib/site-packages/pymc/sampling/mcmc.py:1253) try: [1254](file:///C:/Users/camilo/AppData/Local/Programs/Python/Python311/Lib/site-packages/pymc/sampling/mcmc.py:1254) with sampler: → [1255](file:///C:/Users/camilo/AppData/Local/Programs/Python/Python311/Lib/site-packages/pymc/sampling/mcmc.py:1255) for draw in sampler: [1256](file:///C:/Users/camilo/AppData/Local/Programs/Python/Python311/Lib/site-packages/pymc/sampling/mcmc.py:1256) strace = traces[draw.chain] [1257](file:///C:/Users/camilo/AppData/Local/Programs/Python/Python311/Lib/site-packages/pymc/sampling/mcmc.py:1257) strace.record(draw.point, draw.stats) File c:\Users\camilo\AppData\Local\Programs\Python\Python311\Lib\site-packages\pymc\sampling\parallel.py:471, in ParallelSampler.iter(self) [464](file:///C:/Users/camilo/AppData/Local/Programs/Python/Python311/Lib/site-packages/pymc/sampling/parallel.py:464) task = progress.add_task( [465](file:///C:/Users/camilo/AppData/Local/Programs/Python/Python311/Lib/site-packages/pymc/sampling/parallel.py:465) self._desc.format(self), [466](file:///C:/Users/camilo/AppData/Local/Programs/Python/Python311/Lib/site-packages/pymc/sampling/parallel.py:466) completed=self._completed_draws, [467](file:///C:/Users/camilo/AppData/Local/Programs/Python/Python311/Lib/site-packages/pymc/sampling/parallel.py:467) total=self._total_draws, [468](file:///C:/Users/camilo/AppData/Local/Programs/Python/Python311/Lib/site-packages/pymc/sampling/parallel.py:468) ) [470](file:///C:/Users/camilo/AppData/Local/Programs/Python/Python311/Lib/site-packages/pymc/sampling/parallel.py:470) while self._active: → [471](file:///C:/Users/camilo/AppData/Local/Programs/Python/Python311/Lib/site-packages/pymc/sampling/parallel.py:471) draw = ProcessAdapter.recv_draw(self._active) [472](file:///C:/Users/camilo/AppData/Local/Programs/Python/Python311/Lib/site-packages/pymc/sampling/parallel.py:472) proc, is_last, draw, tuning, stats = draw [473](file:///C:/Users/camilo/AppData/Local/Programs/Python/Python311/Lib/site-packages/pymc/sampling/parallel.py:473) self._completed_draws += 1 File c:\Users\camilo\AppData\Local\Programs\Python\Python311\Lib\site-packages\pymc\sampling\parallel.py:338, in ProcessAdapter.recv_draw(processes, timeout) [336](file:///C:/Users/camilo/AppData/Local/Programs/Python/Python311/Lib/site-packages/pymc/sampling/parallel.py:336) else: [337](file:///C:/Users/camilo/AppData/Local/Programs/Python/Python311/Lib/site-packages/pymc/sampling/parallel.py:337) error = RuntimeError(f"Chain {proc.chain} failed.") → [338](file:///C:/Users/camilo/AppData/Local/Programs/Python/Python311/Lib/site-packages/pymc/sampling/parallel.py:338) raise error from old_error [339](file:///C:/Users/camilo/AppData/Local/Programs/Python/Python311/Lib/site-packages/pymc/sampling/parallel.py:339) elif msg[0] == “writing_done”: [340](file:///C:/Users/camilo/AppData/Local/Programs/Python/Python311/Lib/site-packages/pymc/sampling/parallel.py:340) proc._readable = True ParallelSamplingError: Chain 0 failed with: The model could not be unpickled. This is required for sampling with more than one core and multiprocessing context spawn or forkserver.

1 Like

I am getting the same error in windows 11. Visual Studio Code.

1 Like