Hi!
I’m working the Hodgkin–Huxley model and I’m trying to run it with multiple cores. I’m using odeint from scipy and the following code:
@as_op(itypes=[pt.dvector, pt.dvector], otypes=[pt.dvector])
def pytensor_forward_model_matrix(theta, time):
return solve_ode(theta, time, current=20)
def run_sampling():
with pm.Model() as model:
# Priors using Gamma distributions
gNa = pm.Gamma("gNa", alpha=2, beta=1/60) # Gamma with mean ~120, alpha=2, scale=60 -> scale = 1/beta
gK = pm.Gamma("gK", alpha=3, beta=1/12) # Gamma with mean ~36, alpha=3, scale=12
gl = pm.Gamma("gl", alpha=2, beta=1/0.15) # Gamma with mean ~0.3, alpha=2, scale=0.15
Cm = pm.Gamma("Cm", alpha=2, beta=1/0.5) # Gamma with mean ~1, alpha=2, scale=0.5
# Priors using Normal distributions
VNa = pm.Normal("VNa", mu=theta[3], sigma=5) # Normal distribution for VNa
Vk = pm.Normal("Vk", mu=theta[4], sigma=5) # Normal distribution for Vk
Vl = pm.Normal("Vl", mu=theta[5], sigma=5) # Normal distribution for Vl
# Likelihood
sigma = pm.HalfNormal("sigma", sigma=10) # Half-normal prior for sigma (standard deviation)
# Finer time resolution for solving the ODE
time_fine = np.arange(0, 40.0, 0.01) # Finer time for ODE solution
time_tensor = pt.as_tensor_variable(time_fine)
# Solve the ODE at a finer resolution
ode_solution = pytensor_forward_model_matrix(
pm.math.stack([gNa, gK, gl, VNa, Vk, Vl, Cm]),
time_tensor
)
# Downsample the ODE solution to match the observed data
# Keep every 100th point from the fine resolution to match the 1ms step
ode_solution_downsampled = ode_solution[::100]
# Define the likelihood: observed data compared to the model's ODE solution
pm.Normal("obs", mu=ode_solution_downsampled, sigma=sigma, observed=observed_data) #THIS IS THE LIKELIHOOD
# Get the number of available CPU cores
num_cores = mp.cpu_count()
trace = pm.sample(tune=3000, draws=10000, cores=num_cores, chains=4)
return trace, model
if __name__ == '__main__':
sampler = "NUTS"
# Set multiprocessing context
mp.set_start_method('spawn', force=True)
# Run the sampling
trace, model = run_sampling()
but i’m getting the error:
RemoteTraceback Traceback (most recent call last) RemoteTraceback: “”" Traceback (most recent call last): File “c:\Users\camilo\AppData\Local\Programs\Python\Python311\Lib\site-packages\pymc\sampling\parallel.py”, line 117, in _unpickle_step_method self._step_method = cloudpickle.loads(self._step_method) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File “c:\Users\camilo\AppData\Local\Programs\Python\Python311\Lib\site-packages\pytensor\compile\ops.py”, line 221, in load_back obj = getattr(module, name) ^^^^^^^^^^^^^^^^^^^^^ AttributeError: module ‘main’ has no attribute ‘pytensor_forward_model_matrix’ During handling of the above exception, another exception occurred: Traceback (most recent call last): File “c:\Users\camilo\AppData\Local\Programs\Python\Python311\Lib\site-packages\pymc\sampling\parallel.py”, line 126, in run self._unpickle_step_method() File “c:\Users\camilo\AppData\Local\Programs\Python\Python311\Lib\site-packages\pymc\sampling\parallel.py”, line 119, in _unpickle_step_method raise ValueError(unpickle_error) ValueError: The model could not be unpickled. This is required for sampling with more than one core and multiprocessing context spawn or forkserver. “”" The above exception was the direct cause of the following exception: ValueError Traceback (most recent call last) ValueError: The model could not be unpickled. This is required for sampling with more than one core and multiprocessing context spawn or forkserver. The above exception was the direct cause of the following exception: ParallelSamplingError Traceback (most recent call last) Cell In[17], line 7 4 mp.set_start_method(‘spawn’, force=True) 6 # Run the sampling ----> 7 trace, model = run_sampling() Cell In[15], line 43 40 # Get the number of available CPU cores 41 num_cores = mp.cpu_count() —> 43 trace = pm.sample(tune=3000, draws=10000, cores=num_cores, chains=4) 45 return trace, model File c:\Users\camilo\AppData\Local\Programs\Python\Python311\Lib\site-packages\pymc\sampling\mcmc.py:842, in sample(draws, tune, chains, cores, random_seed, progressbar, progressbar_theme, step, var_names, nuts_sampler, initvals, init, jitter_max_retries, n_init, trace, discard_tuned_samples, compute_convergence_checks, keep_warning_stat, return_inferencedata, idata_kwargs, nuts_sampler_kwargs, callback, mp_ctx, blas_cores, model, **kwargs) [840](file:///C:/Users/camilo/AppData/Local/Programs/Python/Python311/Lib/site-packages/pymc/sampling/mcmc.py:840) _print_step_hierarchy(step) [841](file:///C:/Users/camilo/AppData/Local/Programs/Python/Python311/Lib/site-packages/pymc/sampling/mcmc.py:841) try: → [842](file:///C:/Users/camilo/AppData/Local/Programs/Python/Python311/Lib/site-packages/pymc/sampling/mcmc.py:842) _mp_sample(**sample_args, **parallel_args) [843](file:///C:/Users/camilo/AppData/Local/Programs/Python/Python311/Lib/site-packages/pymc/sampling/mcmc.py:843) except pickle.PickleError: [844](file:///C:/Users/camilo/AppData/Local/Programs/Python/Python311/Lib/site-packages/pymc/sampling/mcmc.py:844) _log.warning(“Could not pickle model, sampling singlethreaded.”) File c:\Users\camilo\AppData\Local\Programs\Python\Python311\Lib\site-packages\pymc\sampling\mcmc.py:1255, in _mp_sample(draws, tune, step, chains, cores, random_seed, start, progressbar, progressbar_theme, traces, model, callback, blas_cores, mp_ctx, **kwargs) [1253](file:///C:/Users/camilo/AppData/Local/Programs/Python/Python311/Lib/site-packages/pymc/sampling/mcmc.py:1253) try: [1254](file:///C:/Users/camilo/AppData/Local/Programs/Python/Python311/Lib/site-packages/pymc/sampling/mcmc.py:1254) with sampler: → [1255](file:///C:/Users/camilo/AppData/Local/Programs/Python/Python311/Lib/site-packages/pymc/sampling/mcmc.py:1255) for draw in sampler: [1256](file:///C:/Users/camilo/AppData/Local/Programs/Python/Python311/Lib/site-packages/pymc/sampling/mcmc.py:1256) strace = traces[draw.chain] [1257](file:///C:/Users/camilo/AppData/Local/Programs/Python/Python311/Lib/site-packages/pymc/sampling/mcmc.py:1257) strace.record(draw.point, draw.stats) File c:\Users\camilo\AppData\Local\Programs\Python\Python311\Lib\site-packages\pymc\sampling\parallel.py:471, in ParallelSampler.iter(self) [464](file:///C:/Users/camilo/AppData/Local/Programs/Python/Python311/Lib/site-packages/pymc/sampling/parallel.py:464) task = progress.add_task( [465](file:///C:/Users/camilo/AppData/Local/Programs/Python/Python311/Lib/site-packages/pymc/sampling/parallel.py:465) self._desc.format(self), [466](file:///C:/Users/camilo/AppData/Local/Programs/Python/Python311/Lib/site-packages/pymc/sampling/parallel.py:466) completed=self._completed_draws, [467](file:///C:/Users/camilo/AppData/Local/Programs/Python/Python311/Lib/site-packages/pymc/sampling/parallel.py:467) total=self._total_draws, [468](file:///C:/Users/camilo/AppData/Local/Programs/Python/Python311/Lib/site-packages/pymc/sampling/parallel.py:468) ) [470](file:///C:/Users/camilo/AppData/Local/Programs/Python/Python311/Lib/site-packages/pymc/sampling/parallel.py:470) while self._active: → [471](file:///C:/Users/camilo/AppData/Local/Programs/Python/Python311/Lib/site-packages/pymc/sampling/parallel.py:471) draw = ProcessAdapter.recv_draw(self._active) [472](file:///C:/Users/camilo/AppData/Local/Programs/Python/Python311/Lib/site-packages/pymc/sampling/parallel.py:472) proc, is_last, draw, tuning, stats = draw [473](file:///C:/Users/camilo/AppData/Local/Programs/Python/Python311/Lib/site-packages/pymc/sampling/parallel.py:473) self._completed_draws += 1 File c:\Users\camilo\AppData\Local\Programs\Python\Python311\Lib\site-packages\pymc\sampling\parallel.py:338, in ProcessAdapter.recv_draw(processes, timeout) [336](file:///C:/Users/camilo/AppData/Local/Programs/Python/Python311/Lib/site-packages/pymc/sampling/parallel.py:336) else: [337](file:///C:/Users/camilo/AppData/Local/Programs/Python/Python311/Lib/site-packages/pymc/sampling/parallel.py:337) error = RuntimeError(f"Chain {proc.chain} failed.") → [338](file:///C:/Users/camilo/AppData/Local/Programs/Python/Python311/Lib/site-packages/pymc/sampling/parallel.py:338) raise error from old_error [339](file:///C:/Users/camilo/AppData/Local/Programs/Python/Python311/Lib/site-packages/pymc/sampling/parallel.py:339) elif msg[0] == “writing_done”: [340](file:///C:/Users/camilo/AppData/Local/Programs/Python/Python311/Lib/site-packages/pymc/sampling/parallel.py:340) proc._readable = True ParallelSamplingError: Chain 0 failed with: The model could not be unpickled. This is required for sampling with more than one core and multiprocessing context spawn or forkserver.