How to average multiple chains?

Hi all,
this question may have been answered somewhere already, but I couldn’t find a solution that works for v5.
It seems that the posterior of my current model has two distinct modes (I am showing you a smoothed 2D histogram of parameters a and b from 30 individual chains using the NUTS sampler):

Can I simply claim that the mode at a < 0.25e-9 and b = 0.6e-9 is less likely, because fewer chains seem to have converged onto it?
Or is there a way to properly weight all the chains?

I could get rid of that mode by adding a constraint a > b, but I would prefer not to do that, as this second mode may also contain some useful information.
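(To be concrete, the constraint I have in mind is something like this sketch, with placeholder priors and a hard pm.Potential, not my actual model:)

import numpy as np
import pymc as pm
import pytensor.tensor as at

with pm.Model():
    a = pm.Normal("a", 0, 1)
    b = pm.Normal("b", 0, 1)
    # hard ordering constraint: logp is -inf wherever a <= b,
    # which removes the second mode but can make sampling inefficient near the boundary
    pm.Potential("order_a_gt_b", at.switch(at.gt(a, b), 0.0, -np.inf))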

So far I have tried:

  • replacing pm.sample with pm.sample_smc, but this didn’t work, as a NaN was produced somewhere while updating the beta value, and sample_smc didn’t like that
  • adding a pm.Deterministic with Y_obs.logp() to track the log-likelihood, but that doesn’t seem to be part of v5 anymore (see the sketch after this list for what I was attempting)
  • weighting the chains by trace.sample_stats.lp, but that doesn’t seem to be the right thing to do, as trace.sample_stats.lp is very similar for all chains
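For the pm.Deterministic attempt, this is roughly what I was trying to reproduce in v5 (a minimal sketch with placeholder data; pm.logp builds the log-density graph that Y_obs.logp() used to return):

import numpy as np
import pymc as pm

data = np.random.normal(size=100)  # placeholder data

with pm.Model() as model:
    mu = pm.Normal("mu", 0, 1)
    Y_obs = pm.Normal("Y_obs", mu=mu, sigma=1, observed=data)
    # v5 replacement for Y_obs.logp(): track the summed log-likelihood per draw
    pm.Deterministic("loglike", pm.logp(Y_obs, data).sum())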

Any help is highly appreciated!

Running 30 chains is highly unusual under NUTS. Two should be sufficient, and 4 is the default. Is there any reason you are running so many? Do the sampling diagnostics suggest that things have converged?

The plot looks the same for 4 chains; I ran 30 to see if there is a difference. The r-hat is quite large (>3), but I assumed that was due to the different chains converging onto the different modes.
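For reference, I am reading the diagnostics off the trace with ArviZ, roughly like this (a sketch; model, a and b stand for my actual model and parameters):

import arviz as az
import pymc as pm

with model:  # my pm.Model instance
    idata = pm.sample(chains=4)

print(az.rhat(idata))                           # per-variable r-hat
print(az.summary(idata, var_names=["a", "b"]))  # ESS, r-hat, ...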

I also ran the same model with pm.Metropolis() and 10 chains, and the result is still quite similar.
Any suggestions what could be happening?

If your different chains are sitting in different regions of the parameter space, averaging/concatenating them isn’t a good idea. The whole goal of MCMC is to get each chain to sample from the posterior. The “only” reason we run more than 1 chain is to actually test this assumption*. If 2 chains are in different parts of the space, then at least one of them (possibly both) is failing to achieve this. So this points to a more fundamental problem, and I would expect the model needs to be substantially changed to address it.

*That’s a bit of an oversimplification.

Okay thank you!
So there is no point in trying to rank the chains by their log-likelihood or similar?

If I remember correctly from older posts, it is possible for multiple modes to exist in the posterior, and then NUTS can have some trouble. I can get rid of this problem by constraining the model, but I wanted to see if there is a way to actually quantify this ‘multi-modality’.

I think NUTS can do ok with a degree of multi-modality, but it really depends on the exact geometry. Accordingly, one strategy for respecifying your model would be to make changes so as to arrive at a geometry that would allow NUTS to hop between modes. But without knowing more about the model, it’s hard to be specific.

If you have truly multi-modal outputs you should not be averaging NUTS chains. All the chains should be converging to the same posterior distribution; that agreement is the evidence of convergence. If you have multi-modal outputs, you should either be using mixture models or a different sampler altogether, such as sequential Monte Carlo.
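For instance, a two-component mixture over one parameter would look roughly like this (a sketch with made-up priors, not the model in question):

import numpy as np
import pymc as pm

with pm.Model():
    w = pm.Dirichlet("w", a=np.ones(2))           # weight of each mode
    mu = pm.Normal("mu", mu=[0.0, 1.0], sigma=1)  # one location per mode
    a = pm.NormalMixture("a", w=w, mu=mu, sigma=0.1)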


Hi,

thank you both for the useful replies. I could remove some of the multi-modality by changing parameters, but not all of it.
When I use pm.sample_smc() I get the error copied below.
My model is built fully in PyTensor and runs normally with NUTS. The likelihood function is generated via pm.Normal(mu=…, observed=…).
Any ideas what could be the issue?

---------------------------------------------------------------------------
RemoteTraceback                           Traceback (most recent call last)
RemoteTraceback: 
"""
Traceback (most recent call last):
  File "C:\Anaconda3\envs\pymc5_env\Lib\multiprocessing\pool.py", line 125, in worker
    result = (True, func(*args, **kwds))
                    ^^^^^^^^^^^^^^^^^^^
  File "C:\Anaconda3\envs\pymc5_env\Lib\multiprocessing\pool.py", line 51, in starmapstar
    return list(itertools.starmap(args[0], args[1]))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Anaconda3\envs\pymc5_env\Lib\site-packages\pymc\smc\sampling.py", line 419, in _apply_args_and_kwargs
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "C:\Anaconda3\envs\pymc5_env\Lib\site-packages\pymc\smc\sampling.py", line 350, in _sample_smc_int
    smc.update_beta_and_weights()
  File "C:\Anaconda3\envs\pymc5_env\Lib\site-packages\pymc\smc\kernels.py", line 274, in update_beta_and_weights
    ESS = int(np.exp(-logsumexp(log_weights * 2)))
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ValueError: cannot convert float NaN to integer
"""

The above exception was the direct cause of the following exception:

ValueError                                Traceback (most recent call last)
Cell In[16], line 7
      3 if __name__ == '__main__':
      5     a = np.where((np.array(df['Time'])>(tmin)) & (np.array(df['Time'])<(tmax)))
----> 7     trace = glm_mcmc_inference_diffusion_backup3(df, np.max(a), Fluence, Surface, Thickness, Absorption_coeff)

Cell In[8], line 59, in glm_mcmc_inference_diffusion_backup3(Data_fit, i, Fluence, Surface, Thickness, Absorption_coeff)
     55     print(np.where(Y_obs_late.eval()==np.nan))
     57     #### Draw Samples from the Posterior Distribution
     58     #step_method = pm.NUTS(early_max_treedepth=4, max_treedepth = 5)
---> 59     trace = pm.sample_smc(progressbar=False)#pm.sample(step=step_method, chains=1, draws=100, tune=1000)#step=TRPL_nuts, tune = 50, draws = no_of_samples, chains=3, discard_tuned_samples=True)
     62 return trace

File C:\Anaconda3\envs\pymc5_env\Lib\site-packages\pymc\smc\sampling.py:213, in sample_smc(draws, kernel, start, model, random_seed, chains, cores, compute_convergence_checks, return_inferencedata, idata_kwargs, progressbar, **kernel_kwargs)
    210 t1 = time.time()
    212 if cores > 1:
--> 213     results = run_chains_parallel(
    214         chains, progressbar, _sample_smc_int, params, random_seed, kernel_kwargs, cores
    215     )
    216 else:
    217     results = run_chains_sequential(
    218         chains, progressbar, _sample_smc_int, params, random_seed, kernel_kwargs
    219     )

File C:\Anaconda3\envs\pymc5_env\Lib\site-packages\pymc\smc\sampling.py:388, in run_chains_parallel(chains, progressbar, to_run, params, random_seed, kernel_kwargs, cores)
    386 params = tuple(cloudpickle.dumps(p) for p in params)
    387 kernel_kwargs = {key: cloudpickle.dumps(value) for key, value in kernel_kwargs.items()}
--> 388 results = _starmap_with_kwargs(
    389     pool,
    390     to_run,
    391     [(*params, random_seed[chain], chain, pbars[chain]) for chain in range(chains)],
    392     repeat(kernel_kwargs),
    393 )
    394 results = tuple(cloudpickle.loads(r) for r in results)
    395 pool.close()

File C:\Anaconda3\envs\pymc5_env\Lib\site-packages\pymc\smc\sampling.py:415, in _starmap_with_kwargs(pool, fn, args_iter, kwargs_iter)
    411 def _starmap_with_kwargs(pool, fn, args_iter, kwargs_iter):
    412     # Helper function to allow kwargs with Pool.starmap
    413     # Copied from https://stackoverflow.com/a/53173433/13311693
    414     args_for_starmap = zip(repeat(fn), args_iter, kwargs_iter)
--> 415     return pool.starmap(_apply_args_and_kwargs, args_for_starmap)

File C:\Anaconda3\envs\pymc5_env\Lib\multiprocessing\pool.py:375, in Pool.starmap(self, func, iterable, chunksize)
    369 def starmap(self, func, iterable, chunksize=None):
    370     '''
    371     Like `map()` method but the elements of the `iterable` are expected to
    372     be iterables as well and will be unpacked as arguments. Hence
    373     `func` and (a, b) becomes func(a, b).
    374     '''
--> 375     return self._map_async(func, iterable, starmapstar, chunksize).get()

File C:\Anaconda3\envs\pymc5_env\Lib\multiprocessing\pool.py:774, in ApplyResult.get(self, timeout)
    772     return self._value
    773 else:
--> 774     raise self._value

ValueError: cannot convert float NaN to integer

@ricardoV94 may be more familiar with the SMC sampler.

There are nans in the log-weights; could your logp be returning nan at some point? Why so?

When I do pm.logp(Y_obs, test_val).eval() I don’t get any NaN values.
Is there anything the SMC sampler could struggle with that I am not aware of, like HalfNormal distributions, the matrix output of the function I am using, pytensor.dimshuffle(), or pytensor.switch()?

If you are using a model with custom operations, it would be helpful if you could share it. It sounds to me like you get nans somewhere (not at the initial point), so maybe add a res = switch(isnan(x), -np.inf, res) if you have a custom logp and see how it fares.
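In code, that guard is just (a sketch; res stands for whatever your custom logp computes):

import numpy as np
import pytensor.tensor as at

def guard_logp(res):
    # map nan logp values to -inf so the SMC importance weights stay finite
    return at.switch(at.isnan(res), -np.inf, res)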

You should, however, figure out the source of your nan; usually that highlights a bug or a bad parametrization.
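A couple of ways to look for the nan (a sketch; model is your pm.Model instance):

print(model.point_logps())              # per-variable logp at the initial point

logp_fn = model.compile_logp()          # compiled joint logp
print(logp_fn(model.initial_point()))   # evaluate at any point dict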

Hi,
here is the important part of the code:

import numpy as np
import pymc as pm
import pytensor.tensor as at
from pytensor import shared


# i (the index of the last fitted time point) is added to the signature,
# since the body and the call in the traceback above both use it
def mcmc_inference_diffusion(Data, i, Surface, Thickness, A_fact):

    #### Setting up the Data and Timeframe
    time = np.array(Data['Time'])[:i+1]*1e-9    #s
    y_combined = np.zeros((len(time),len(Surface)))
   
    for a in range(len(Surface)):
        y_combined[:,a] = np.array(Data[str(a)])[:i+1]
    
    
    with pm.Model() as model:
        
        #### Defining Model Parameters
        ## Diffusion
        S_ratio_power = pm.Normal("S_ratio_power",0,3, initval=1)
        S_ratio_model = pm.Deterministic('S_ratio_model', 10**(S_ratio_power))

        Beta_val_power = pm.HalfNormal("Beta_val_power", 1, initval=1)
        Beta_val_model = pm.Deterministic("Beta_val_model", at.exp(-Beta_val_power))  # at.exp keeps the op inside the PyTensor graph
    
        S_substrate_power_offset = pm.HalfNormal('S_substrate_power_offset',1, initval=0.1)
        S_substrate_model = pm.Deterministic('S_substrate_model', 10**(S_substrate_power_offset*3))
        
        
        ## Recombination
        k1_power_offset = pm.HalfNormal('k1_power_offset', 1)
        k1_model = pm.Deterministic('k1_model', 10**(3 + k1_power_offset*2))

        p0_power_offset = pm.HalfNormal('p0_power_offset', 1)
        p0_model = pm.Deterministic('p0_model', 10**(13 + p0_power_offset*2))
                
        
        #### Simulation of Time-Resolved PL
        N_calc = diffusion(S_ratio_model, Beta_val_model, S_substrate_model, k1_model, p0_model, shared(time), shared(Surface),Thickness, shared(A_fact))    
        
        ## Likelihood Function (Normal Distribution)
        Y_obs = pm.Normal("Y_obs_late", mu=at.sqrt(y_combined[1:i,:]/N_calc[1:i,:]), sigma=1,  observed=np.ones(shape=np.shape(y_combined[1:i,:]))) #sigma=np.sqrt((i-max_locator)/5000)       
                
        #### Draw Samples from the Posterior Distribution
        trace = pm.sample_smc(progressbar=False)
        
    return trace

The function diffusion() is:

def diffusion(S_ratio_model, Beta_val_model, S1_model, k1_model, p0_model, time, Surface, thickness, A_fact):

    ### Bulk Recombination Parameters
    k1 = k1_model
    p0 = p0_model

    ## Front/back surface recombination velocities; S2 is assumed here to be
    ## S1 scaled by the fitted ratio (the original snippet used the undefined
    ## names S1, S2 and beta_diffusion)
    S1 = S1_model
    S2 = S1_model*S_ratio_model

    beta_est = beta_estimator(thickness, S1, S_ratio_model, Beta_val_model)  # polynomial function
    ## Diff_calc_high and Diff_calc_low are polynomial functions
    Diffusion = at.switch(at.gt(beta_est, 0.5),
                          Diff_calc_high(S1, S2, beta_est, thickness),
                          Diff_calc_low(S1, S2, beta_est, thickness))
   
    
    beta_model = beta_est*np.pi/(thickness*1e-7)

    S_front = at.switch(at.eq(Surface,1),S1_model,S1_model*S_ratio_model)

    ### Equation (24)   
    S_front_4d = S_front.dimshuffle(0,'x','x','x')
    # S_front_4d.broadcastable == (False, True, True, True)
        
    Diffusion_4d = Diffusion#.dimshuffle(0,'x','x','x')

    beta_4d = beta_model.dimshuffle('x', 0 ,'x','x')
    # beta_4d.broadcastable == (True, False, True, True)
   
    ### Defining the spatial domain
    x = np.arange(20)
    stretch = 2.5   # Factor that stretches the tanh-grid towards the edges
    thickness_tanh_spacing = thickness/2*(1-np.tanh(stretch*(1-2*x/len(x)))/np.tanh(stretch))
    
    z_array = at.as_tensor_variable(thickness_tanh_spacing*1e-7)
    z_array_4d = z_array.dimshuffle('x','x',0, 'x')
    # z_array_4d.broadcastable == (True, True, False, True)

    U_z = at.cos(beta_4d*z_array_4d)+S_front_4d/(Diffusion_4d*beta_4d)*at.sin(beta_4d*z_array_4d)
    # U_z.broadcastable == (False, False, False, True)

    ### Equation (25)
    A_fact_4d= (A_fact).dimshuffle(0,'x','x','x')
    # A_fact_4d.broadcastable == (False, True, True, True)
    
    A_param = ((at.exp(-A_fact_4d*z_array_4d)*U_z).sum(axis=2)/(at.power(U_z,2)).sum(axis=2))

    A_param_4d = A_param.dimshuffle(0,1,2,'x')
    # A_param_4d.broadcastable == (False, False, True, True)
   
    time_4d = time.dimshuffle('x','x','x',0)
    # time_4d.broadcastable == (True, True, True, False)

    n_tz0 = (A_param_4d*U_z * at.exp(-(Diffusion_4d*at.power(beta_4d,2))*time_4d)).sum(axis=1) # sum over all beta values
    
    ## Recombination
    time_3d = time.dimshuffle('x','x',0)
    # time_3d.broadcastable == (True, True, False)
    
    n_tz1 = (n_tz0*at.exp(-(k1*(n_tz0+2*p0)/(n_tz0+p0))*time_3d))
    n_tz = n_tz1.sum(axis=1)# sum over thickness
    
    ### Equation (13) & (32)  
    I_t = (2*p0*n_tz+n_tz**2)

    ### Normalizing to compare with the data        
    PL_value_norm = at.mul(I_t.T,1/I_t.max(axis=1))

    return PL_value_norm

Sorry for the long code example.
Any help is greatly appreciated!

I don’t quite know how PyMC calculates logp(), but is it possible that a chain runs into a region where (simulation - data)^2 becomes so large that it causes an overflow and logp() becomes NaN?
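(To illustrate what I mean in plain NumPy: the squared residual overflows to inf, the logp becomes -inf, and -inf arithmetic further on produces NaN:)

import numpy as np

resid = 1e200            # absurdly large (simulation - data)
logp = -0.5 * resid**2   # resid**2 overflows to inf, so logp is -inf
print(logp)              # -inf
print(logp - logp)       # -inf + inf gives nan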

I tried running only part of the code and replaced all exp(x) with exp(x)/(1 + exp(x)) so that the result stays bounded if the exponential overflows. Now I get this error:

RuntimeWarning: invalid value encountered in scalar divide
  (between_chain_variance / within_chain_variance + num_samples - 1) / (num_samples)

For some reason within_chain_variance seems to be zero now for at least one of the chains.

Is it possible that the new expression is producing zeros for all draws and thus there is no variance?

Yes, that seems to be happening.
I limited the range of the parameters and now it runs!
Thank you for your help!

One last question: Is there a way to use progressbar = True?
I know that this is a common issue, but is there a way to use it with pm.sample_smc() now?


The progress bar should be working. If you are using VSCode there seems to be a bug in the progress bar package that prevents it from displaying.