No divergences but high R-hat

Hi, I am implementing a version of the Stochastic Block Model with very simple priors. I don’t have divergences, but the R-hat is still higher than 1.01 for three of the parameters. I’m not sure how to improve it. Any help?

def custom_sigmoid(x):
    return pm.math.exp(x) / (1 + pm.math.exp(x))

def create_model(data_matrix, num_nodes, num_blocks, a_alpha, b_alpha, a_tau, b_tau, mu_zeta, sigma_zeta):
    
#---------------------------- Data  -----------------------------#
# Consider only the upper triangle (above the diagonal)
    data_matrix = data_matrix.reshape(num_nodes * num_nodes).T
    mask = np.triu(np.ones((num_nodes, num_nodes)), k=1).astype(bool)
    data_matrix = data_matrix.reshape(num_nodes, num_nodes)[mask]

    a_alpha
    b_alpha
    a_tau
    b_tau
    mu_zeta
    sigma_zeta

#---------------------------- Prior Parameters 1 ---------------------------#
    with pm.Model() as SBMmodel:
        # Alpha: Gamma distribution for the blocks
        alpha = pm.Gamma("alpha", alpha=a_alpha, beta=b_alpha)

        # Omega: Dirichlet distribution for the block membership probabilities
        omega = pm.Dirichlet("omega", a=np.ones(num_blocks) * (alpha / num_blocks), shape=num_blocks)    
            
        # xi: assignment vector (block label for each node)
        E_row = pm.Categorical('E_vector', p=omega, shape=num_nodes)
        
        # tau2: variance of the interaction probability between blocks
        tau2 = pm.InverseGamma("tau2", alpha=a_tau, beta=b_tau)

        # zeta: mean of the interaction probability between blocks
        zeta_raw = pm.Normal("zeta", mu=mu_zeta, sigma=np.sqrt(sigma_zeta))
        zeta = custom_sigmoid(zeta_raw)
        #zeta = pm.TruncatedNormal("zeta", mu=mu_zeta, sigma=np.sqrt(sigma_zeta), lower=0, upper=1)
        
        # Theta: interaction probability between blocks (upper triangle, including the diagonal)
        #Theta_kl = pm.TruncatedNormal('Theta_kl', mu=zeta, sigma=np.sqrt(tau2), lower=0, upper=1, shape=(num_blocks * (num_blocks + 1)) // 2)
        Theta_kl_raw = pm.Normal('Theta_kl', mu=zeta, sigma=np.sqrt(tau2), shape=(num_blocks * (num_blocks + 1)) // 2)
        Theta_kl = custom_sigmoid(Theta_kl_raw)
        
        Theta_matrix = np.zeros((num_blocks, num_blocks), dtype=object)
        index = 0
        for i in range(num_blocks):
            for j in range(i, num_blocks):
                Theta_matrix[i, j] = Theta_kl[index]
                Theta_matrix[j, i] = Theta_kl[index]  # symmetry
                index += 1
        
#---------------------------- Deterministic function for Theta_E ---------------------------#
# define the matrix Theta_E
# Dimension: num_nodes x num_nodes 
# Each value is the probability of interaction between two nodes depending on the block they belong to
           
    def create_Theta_E_matrix(E_row, num_blocks, num_nodes):
        E_matrix = np.zeros((num_blocks, num_nodes), dtype=int)
        for node_index in range(num_nodes):
                block_index = E_row[node_index]  # get the block assigned to this node
                E_matrix[block_index, node_index] = 1  # set 1 at the corresponding position

        # Initialize Theta_E_matrix
        Theta_E_matrix = np.zeros((num_nodes, num_nodes, 2), dtype=int)
        indices = [np.where(E_matrix[i] == 1)[0][0] + 1 for i in range(num_nodes)]
        for i in range(num_nodes):
            for j in range(num_nodes):
                I = indices[i]
                K = indices[j]
                Theta_E_matrix[i, j] = (I, K)  # Assign the ordered pair

        return Theta_E_matrix

        rows, cols = len(Theta_E_matrix), len(Theta_E_matrix[0])
        Theta_E_updated = np.zeros((rows, cols), dtype=int)

        # iterate over the ordered pairs in Theta_E_matrix.
        for i in range(rows):
            for j in range(cols):
                coord_i, coord_j = Theta_E_matrix[i][j]

                valor_theta = Theta_matrix[coord_i-1, coord_j-1]

                Theta_E_matrix[i, j] = valor_theta

#---------------------------- Deterministic function for Likelihood ---------------------------#

    def compute_bernoulli_parameters(Theta_E_matrix):
        bernoulli_parameters = []
        for i in range(num_nodes):
            row = []
            for j in range(num_nodes):
                param = pm.math.log(Theta_E_matrix[i][j] / (1 - Theta_E_matrix[i][j]))
                row.append(param)
            bernoulli_parameters.append(row)
        return at.stack(bernoulli_parameters)

        
# Calculate the Bernoulli parameters
        bernoulli_params = compute_bernoulli_parameters(Theta_E_matrix)
        
# Use pm.Deterministic to include the deterministic function in the model
        bernoulli_parameters = pm.Deterministic('bernoulli_parameters', bernoulli_params)
    
# Observed data
        y_adyencence = pm.Bernoulli('y_matrix', p=bernoulli_parameters, observed=data_matrix)
    
    return SBMmodel

Hi Andrea, with R-hat warnings there are basically three cases:

  1. The chains are too short. Sometimes the chains find the same typical set in the posterior and are sampling happily, but sampling stops before it becomes obvious that all the chains are in the same neighborhood. Increasing the number of draws is a simple way to test this.
  2. The posterior is multi-modal. Sometimes there are two or more distinct neighborhoods with high posterior probability, and the chains will not converge with more samples. It can be helpful to look at the posterior mean by chain with something like trace.posterior.mean(dim="draw"). az.plot_trace(trace) also makes this obvious: in the right-hand plot, you'll see separated horizontal bands for one or more parameters. If you are in this situation, it's worth asking whether all the modes are plausible. If not, you can use tighter priors to nudge the sampler away from the bad modes.
  3. High R-hat is a proxy for something else. If one of your parameters is very hard to sample, the chains might not be able to move freely through the posterior. You can detect that by looking at the effective sample size (ESS): if it is low, the chains are not moving freely. I like az.ess(trace), but ESS also shows up in az.summary(trace).

In each of these cases, you can get the R-hat warning without having divergences. The two often go together if you are sampling from a really badly behaved posterior, but they measure different problems.
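
To make case 2 concrete, here is a small sketch on synthetic draws (not your model). It computes the classic split R-hat by hand using only the standard library; ArviZ's default is a rank-normalized variant, so the numbers will not match az.rhat exactly. The helper names split_rhat and draw_chain are made up for this illustration.

```python
import math
import random
import statistics


def split_rhat(chains):
    """Classic (non-rank-normalized) split R-hat for equal-length chains."""
    half = len(chains[0]) // 2
    # Split every chain in two so that drift within a chain also inflates R-hat.
    halves = [c[:half] for c in chains] + [c[half:2 * half] for c in chains]
    n = half
    # W: average within-chain variance.
    within = statistics.fmean(statistics.variance(h) for h in halves)
    # B / n: variance of the per-chain means.
    between_over_n = statistics.variance([statistics.fmean(h) for h in halves])
    var_plus = (n - 1) / n * within + between_over_n
    return math.sqrt(var_plus / within)


rng = random.Random(0)  # fixed seed so the illustration is reproducible


def draw_chain(mean, n_draws=1000):
    """Hypothetical stand-in for one sampler chain: independent normal draws."""
    return [rng.gauss(mean, 1.0) for _ in range(n_draws)]


# Four chains exploring the same region of the posterior.
mixed = [draw_chain(0.0) for _ in range(4)]
# Two chains stuck near 0 and two stuck near 3, as with a bimodal posterior.
bimodal = [draw_chain(0.0), draw_chain(0.0), draw_chain(3.0), draw_chain(3.0)]

rhat_mixed = split_rhat(mixed)
rhat_bimodal = split_rhat(bimodal)
# Rough analogue of trace.posterior.mean(dim="draw") for a single parameter.
chain_means = [statistics.fmean(c) for c in bimodal]

print("R-hat, well-mixed chains:", round(rhat_mixed, 3))
print("R-hat, chains in separate modes:", round(rhat_bimodal, 3))
print("per-chain means (separate modes):", [round(m, 2) for m in chain_means])
```

The well-mixed chains give an R-hat very close to 1, while the chains stuck in separate modes give a value far above 1.01 and per-chain means that split into two groups. Running longer would not fix the second case, which is why the per-chain means are worth checking before simply adding draws.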
