Why Would a Variable with Few Non-zero Values Cause the Model to Fail to Converge?

I’m new to the world of Bayesian statistics. I’m working on a hierarchical model with quite a few variables, and one of them is preventing the model from converging. With the ‘numpyro’ sampler, sampling finishes very quickly and every variable ends up with an r_hat of inf. With the ‘pymc’ sampler, it won’t even run. I identified the problematic variable by elimination: 98% of its values are zero, but the remaining 2% are all valid values, nothing crazy. I’m curious why this happens, and whether there are any potential fixes if I have to include this variable in the model. Thank you very much.
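A quick way to put numbers on "mostly zeros, nothing crazy" is to summarize the column before it goes into the model: how sparse it is, whether anything is NaN, and how large the non-zero values are compared with the other predictors. A minimal NumPy sketch, where `summarize_sparse` and the example array are hypothetical stand-ins for `df['Var1_']`:

```python
import numpy as np


def summarize_sparse(x) -> dict:
    """Summarize a mostly-zero predictor: sparsity, NaNs, and scale of the non-zero part."""
    x = np.asarray(x, dtype=float)  # fails loudly if the column is not numeric
    nonzero = x[x != 0]             # NaN != 0, so NaNs would land here too
    has_nonzero = nonzero.size > 0
    return {
        "n": int(x.size),
        "frac_zero": float(np.mean(x == 0)),
        "n_nonzero": int(nonzero.size),
        "any_nan": bool(np.isnan(x).any()),
        "nonzero_min": float(nonzero.min()) if has_nonzero else float("nan"),
        "nonzero_max": float(nonzero.max()) if has_nonzero else float("nan"),
        "nonzero_std": float(nonzero.std()) if has_nonzero else float("nan"),
    }


# Hypothetical column: 98 zeros and two non-zero values.
example = np.zeros(100)
example[[10, 50]] = [1500.0, 3200.0]
print(summarize_sparse(example))
```

If the non-zero values turn out to be in the hundreds or thousands while the other predictors are near unit scale, that is worth knowing before looking at the sampler.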

You have to share more details about the model. numpyro running that fast just means it’s getting 100% divergences.
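To check this after a numpyro run, count the divergent transitions. The `InferenceData` returned by `pm.sample` keeps them in its `sample_stats` group as a boolean `diverging` variable with one entry per chain and draw. The helper below is a hypothetical sketch that only needs that array, so it is shown on a made-up input:

```python
import numpy as np


def divergence_report(diverging) -> dict:
    """Summarize a (chain, draw) boolean array of divergent transitions."""
    diverging = np.asarray(diverging, dtype=bool)
    per_chain = diverging.sum(axis=1)  # divergences in each chain
    return {
        "total": int(diverging.sum()),
        "fraction": float(diverging.mean()),
        "per_chain": per_chain.tolist(),
    }


# With a real run this would be: idata.sample_stats["diverging"].values
# Here a made-up array mimics 4 chains x 2000 draws that all diverge.
fake = np.ones((4, 2000), dtype=bool)
print(divergence_report(fake))
```

A fraction at or near 1.0 means the sampler never moved in a valid way, which is consistent with the infinite r_hat values.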

Hi Ricardo, thank you very much for your response. Here is what the code structure looks like. It’s a hierarchical model with a nested structure; at least that’s what I’m trying to build, but I’m not sure I have everything set up correctly. The dependent variables are categorical. With the ‘pymc’ sampler, I get the error “Initial evaluation of model at starting point failed!”. Say the variable causing the issue is ‘Var1_’: if I remove ‘Var1_’, sampling runs. Yet all the values of ‘Var1_’ in the dataframe are valid; it’s just that 98% of them are 0. I hope this provides more information. I really appreciate the help!
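One possible mechanism, assuming (the post does not say) that the non-zero values of `Var1_` are large relative to the other predictors: the term `Var1_[dim3_idx] * df['Var1_']` then produces very large utilities in exactly the 2% of rows where the column is non-zero, even at the sampler's starting point, while the zero rows contribute nothing. Dropping the column removes those extreme rows, which would match what you observe. A common remedy is to rescale the predictor. The hypothetical helper `scale_nonzero` below divides by the standard deviation of the non-zero entries only, so zeros stay exactly zero:

```python
import numpy as np


def scale_nonzero(x):
    """Divide a mostly-zero predictor by the std of its non-zero entries.

    Zeros stay exactly zero. Returns the scaled array and the scale used,
    so coefficients can be converted back to the original units.
    """
    x = np.asarray(x, dtype=float)
    nonzero = x[x != 0]
    if nonzero.size == 0:
        return x.copy(), 1.0
    scale = float(nonzero.std())
    if scale == 0.0:  # a single non-zero value, or all non-zero values equal
        scale = float(np.abs(nonzero).max())
    return x / scale, scale


# Hypothetical stand-in for df["Var1_"]: 98 zeros and two large non-zero values.
x = np.zeros(100)
x[[10, 50]] = [1500.0, 3200.0]
x_scaled, scale = scale_nonzero(x)
print(scale)                       # 850.0
print(np.count_nonzero(x_scaled))  # 2
```

You would then use `x_scaled` in place of the raw column. The `Var1_` coefficient is then per `scale` units of the original variable, so dividing its posterior by `scale` converts it back. Centering is deliberately skipped so that a value of zero keeps its original meaning; whether that matters depends on what zero means for your variable.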

```python
import pymc as pm

# df (a pandas DataFrame), Weight, dv_TT, dv_model, dim1_idx, dim3_idx,
# dim1, dim2, dim3, Variables (a list of parameter names) and N are
# defined earlier in the script (not shown).

coords = {
    "dim1": dim1,
    "dim2": dim2,
    "dim3": dim3,
    "Variables": Variables,
    "obs": range(N),
}

with pm.Model(coords=coords) as model:

    # Shared hyperpriors: one location/scale pair per name in Variables
    mu = pm.Normal("mu", 0, 1000, dims="Variables")
    sigma = pm.HalfNormal("sigma", sigma=1000, dims="Variables")

    # TT part: alternative-specific intercepts indexed by (dim1, dim3)
    InterceptC_M = pm.Normal("InterceptC_M", mu[Variables.index("InterceptC_M")], sigma[Variables.index("InterceptC_M")], dims=("dim1", "dim3"))
    InterceptL_M = pm.Normal("InterceptL_M", mu[Variables.index("InterceptL_M")], sigma[Variables.index("InterceptL_M")], dims=("dim1", "dim3"))
    InterceptF_M = pm.Normal("InterceptF_M", mu[Variables.index("InterceptF_M")], sigma[Variables.index("InterceptF_M")], dims=("dim1", "dim3"))

    # Coefficient of the mostly-zero predictor Var1_, one per dim3 level
    Var1_ = pm.Normal("Var1_", mu[Variables.index("Var1_")], sigma[Variables.index("Var1_")], dims="dim3")

    x_var1 = df["Var1_"].to_numpy()
    u_A = InterceptC_M[dim1_idx, dim3_idx] + Var1_[dim3_idx] * x_var1
    u_B = InterceptL_M[dim1_idx, dim3_idx] + Var1_[dim3_idx] * x_var1
    u_C = InterceptF_M[dim1_idx, dim3_idx] + Var1_[dim3_idx] * x_var1

    TT_stack = pm.math.stack([u_A, u_B, u_C]).T  # shape (N, 3)
    p_TT = pm.math.softmax(TT_stack, axis=1)

    # Weighted categorical log-likelihood of the observed TT choices
    like_TT = pm.Potential("like_TT", Weight * pm.logp(pm.Categorical.dist(p=p_TT), dv_TT), dims="obs")

    # Second part: utilities over the dim1 alternatives
    Intercept = pm.Normal("Intercept", mu[Variables.index("Intercept")], sigma[Variables.index("Intercept")], dims=("dim1", "dim3"))
    v = pm.Normal("v", mu[Variables.index("v")], sigma[Variables.index("v")], dims="dim3")
    Var2_ = pm.HalfNormal("Var2_", sigma=sigma[Variables.index("Var2_")], dims="dim3")
    Var3_ = pm.Normal("Var3_", mu[Variables.index("Var3_")], sigma[Variables.index("Var3_")], dims="dim1")

    # Log-sum-exp of the three TT utilities, fed into the second part
    logsum_TT = pm.math.log(pm.math.sum(pm.math.exp(TT_stack), axis=1))

    # One utility per element of dim1 (the expression does not depend on the loop variable)
    util = []
    for _ in dim1:
        util.append(
            Intercept[dim1_idx, dim3_idx]
            - Var2_[dim3_idx] * df["Var2_"].to_numpy()
            + Var3_[dim1_idx] * df["Var3_"].to_numpy()
            + v[dim3_idx] * logsum_TT
        )

    model_stack = pm.math.stack(util).T  # shape (N, len(dim1))
    p_model = pm.math.softmax(model_stack, axis=1)

    # Weighted categorical log-likelihood of the observed model choices
    like_model = pm.Potential("like_model", Weight * pm.logp(pm.Categorical.dist(p=p_model), dv_model), dims="obs")

    idata = pm.sample(
        draws=2000,
        chains=4,
        cores=4,
        tune=2000,
        nuts_sampler="numpyro",
        idata_kwargs={"log_likelihood": False},
        nuts_sampler_kwargs={"postprocessing_vectorize": "scan"},
        random_seed=1301,
    )
```
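The term `pm.math.log(pm.math.sum(pm.math.exp(TT_stack), axis=1))` exponentiates the utilities directly, so in any row where a utility exceeds roughly 709 (the largest argument `exp` can take in float64) the result overflows to `inf`, and multiplying it by `v` can make the second-level utilities and `p_model` non-finite. The NumPy sketch below uses made-up utilities, not your data, to compare that naive form with the usual max-shifted log-sum-exp. PyTensor provides the stable version as `logsumexp`, which recent PyMC versions expose as `pm.math.logsumexp(TT_stack, axis=1)`; it is worth confirming that it exists in your installed version before relying on it.

```python
import numpy as np


def naive_lse(u):
    """Row-wise log(sum(exp(u))), written the same way as in the model; overflows easily."""
    return np.log(np.sum(np.exp(u), axis=1))


def stable_lse(u):
    """Row-wise log-sum-exp, shifted by each row's maximum so exp never overflows."""
    m = np.max(u, axis=1, keepdims=True)
    return (m + np.log(np.sum(np.exp(u - m), axis=1, keepdims=True)))[:, 0]


# Hypothetical utilities for two observations and three alternatives (u_A, u_B, u_C).
# Row 0 mimics Var1_ == 0; row 1 mimics a large non-zero Var1_ times its coefficient.
u = np.array([[0.5, -0.2, 0.1],
              [800.0, 799.0, 0.3]])

with np.errstate(over="ignore"):
    print(naive_lse(u))  # second entry is inf
print(stable_lse(u))     # both entries finite
```

The two functions agree wherever the naive one is finite; the shift only changes how the same quantity is computed.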