Unused variable causing divergences?

While working on exercise 8, chapter 3, from “Bayesian Analysis with Python” by Osvaldo Martin, I ended up writing the following:

import seaborn as sns
import pymc3 as pm
import arviz as az
import pandas as pd
import numpy as np

tips = sns.load_dataset("tips")

tip = tips['tip'].to_numpy()
idx = pd.Categorical(tips['day'],
                     categories=['Thur', 'Fri', 'Sat', 'Sun']).codes
groups = len(np.unique(idx))

with pm.Model() as comparing_groups_hyper:
    
    μ_μ = pm.Normal('μ_μ', 0, sd=10)
    
    μ = pm.Normal('μ', mu=μ_μ, sd=10, shape=groups)
    σ = pm.HalfNormal('σ', sd=3, shape=groups)

    y = pm.Normal('y', mu=μ[idx], sd=σ[idx], observed=tip)

    trace_cg_hyper = pm.sample(1000, return_inferencedata=True)
    
az.summary(trace_cg_hyper)

This works without divergences.
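
As a quick check, since the trace comes back as InferenceData, the divergence flags live in sample_stats; a minimal sketch:

# Count divergent transitions; for the model above this prints 0
print(trace_cg_hyper.sample_stats.diverging.sum().item())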

If I then add an extra variable but don’t use it anywhere, all of a sudden I get hundreds of divergences:

import seaborn as sns
import pymc3 as pm
import arviz as az
import pandas as pd
import numpy as np

tips = sns.load_dataset("tips")

tip = tips['tip'].to_numpy()
idx = pd.Categorical(tips['day'],
                     categories=['Thur', 'Fri', 'Sat', 'Sun']).codes
groups = len(np.unique(idx))

with pm.Model() as comparing_groups_hyper:
    
    μ_μ = pm.Normal('μ_μ', 0, sd=10)
    σ_μ = pm.HalfNormal('σ_μ', 10)
    
    μ = pm.Normal('μ', mu=μ_μ, sd=10, shape=groups)
    σ = pm.HalfNormal('σ', sd=3, shape=groups)

    y = pm.Normal('y', mu=μ[idx], sd=σ[idx], observed=tip)

    trace_cg_hyper = pm.sample(1000, return_inferencedata=True)
    
az.summary(trace_cg_hyper)

I find this quite surprising and confusing; if anyone could help me understand why that’s the case, it would be appreciated!

Hmmm, looking at the parallel plot (az.plot_parallel(trace_cg_hyper, var_names=['μ','σ','μ_μ','σ_μ'])), it doesn’t seem the divergences are concentrated in a specific region. I don’t have a good intuition for why you would get divergences, but one explanation is that the unused variable is at a larger scale than the other random variables, and thus needs a larger velocity, resulting in a larger energy at each sample (more risk of divergence due to numerical imprecision).
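
Since this explanation is in terms of the energy at each sample, the energy plot is the matching diagnostic; a minimal sketch using the same trace:

# Compare the marginal energy distribution against the energy transition
# distribution; a mismatch (low BFMI) indicates the sampler is struggling
az.plot_energy(trace_cg_hyper)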

Thanks @junpenglao for your answer!

I still get divergences if I set the unused variable to a smaller scale, though:

σ_μ = pm.HalfNormal('σ_μ', .01)

Any ideas on how to set it such that the model will converge? I’d like to be able to use it as a prior for μ’s sigma.

When you are using it as a prior, it should be fine.

I still get hundreds of divergences when running this, though:

import seaborn as sns
import pymc3 as pm
import arviz as az
import pandas as pd
import numpy as np

tips = sns.load_dataset("tips")

tip = tips['tip'].to_numpy()
idx = pd.Categorical(tips['day'],
                     categories=['Thur', 'Fri', 'Sat', 'Sun']).codes
groups = len(np.unique(idx))

with pm.Model() as comparing_groups_hyper:
    
    μ_μ = pm.Normal('μ_μ', 0, sd=10)
    σ_μ = pm.HalfNormal('σ_μ', 1)
    
    μ = pm.Normal('μ', mu=μ_μ, sd=σ_μ, shape=groups)
    σ = pm.HalfNormal('σ', sd=3, shape=groups)

    y = pm.Normal('y', mu=μ[idx], sd=σ[idx], observed=tip)

    trace_cg_hyper = pm.sample(1000, return_inferencedata=True)

Hi Marco, nice seeing you here :slight_smile:

Did you try some prior predictive checks? “Hyper standard deviations” (i.e., stds at the hyper-prior level in hierarchical models) are usually hard to identify, so I’d do some checks prior to fitting, to understand how μ_μ, σ_μ, and σ interact. My hunch is that an std of 10 for the hyper mean is too big when you try to infer the hyper std as well :man_shrugging:
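
A minimal sketch of such a check, reusing the comparing_groups_hyper model and the tip array from above:

with comparing_groups_hyper:
    # Draw from the priors only, without conditioning on the data
    prior_pred = pm.sample_prior_predictive(samples=500)

# Eyeball the simulated tips against the observed ones
az.plot_dist(prior_pred['y'].flatten(), label='prior predictive')
az.plot_dist(tip, color='C1', label='observed')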

+1 to what @AlexAndorra said. Since you have 4 groups, there is not a lot of information to infer the hyper std; think about it this way: you want to infer 2 parameters (mu and sigma) from 4 observations. Setting a stronger prior that keeps σ_μ away from 0 (which is where most of the divergences come from) should help:

with pm.Model() as comparing_groups_hyper:
    
    μ_μ = pm.Normal('μ_μ', 0, sigma=10)
    σ_μ = pm.Gamma('σ_μ', mu=.5, sigma=.1)
    
    μ = pm.Normal('μ', mu=μ_μ, sigma=σ_μ, shape=groups)
    σ = pm.HalfNormal('σ', sigma=3, shape=groups)

    y = pm.Normal('y', mu=μ[idx], sigma=σ[idx], observed=tip)

    trace_cg_hyper = pm.sample(1000, return_inferencedata=True)
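
To see why the Gamma keeps σ_μ away from 0, it helps to convert the mu/sigma parameterization to the standard shape/rate one (the arithmetic below is just illustration):

# pm.Gamma(mu=.5, sigma=.1) corresponds to shape α = mu²/σ² and rate β = mu/σ²
mu, sigma = 0.5, 0.1
alpha = (mu / sigma) ** 2  # 25.0
beta = mu / sigma ** 2     # 50.0
# With α > 1 the density is 0 at σ_μ = 0 (the mode is (α-1)/β = 0.48),
# so the sampler is pushed away from the funnel region near 0, unlike the
# HalfNormal, whose density is highest exactly at 0.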

Thanks @AlexAndorra, @junpenglao! I tried greatly restricting the sigmas, and with this I’m able to get convergence:

with pm.Model() as comparing_groups_hyper:
    
    μ_μ = pm.Normal('μ_μ', 3, sigma=.5)
    σ_μ = pm.Gamma('σ_μ', mu=.5, sigma=.1)
    
    μ = pm.Normal('μ', mu=μ_μ, sigma=σ_μ, shape=groups)
    σ = pm.HalfNormal('σ', sigma=3, shape=groups)

    y = pm.Normal('y', mu=μ[idx], sigma=σ[idx], observed=tip)

    trace_cg_hyper = pm.sample(1000, return_inferencedata=True)