Problem with divergence in "localized" models: extreme values for Half Cauchy variables

Per “A General Method for Robust Bayesian Modeling” (Wang & Blei) we have been trying to build “localized” models to be more robust to outliers.

For these models, we replace our output variable, Out ~ TruncNorm(ϕ, 𝜎, 0, 𝒏) with a variable that has a vector of iid 𝜎’s, one per data point. In this case, we use Half Cauchy RVs for the 𝝈’s.

Our non-localized model works fine on some toy cases (without outliers). But to our surprise, when we use the localized model on the same cases, we get divergences. When we look at the diverging points, we see 𝝈(i) values that spike up to, e.g., 500 or 1500.

That happens on toy problems, with data generated by scipy.stats.truncnorm. On the real problems, where we try to find the posterior for actual data, we get whole traces that are entirely divergent.

One other thing we noticed: on the toy problem, when we set the lower bound of the truncated normal to 0 - 𝝐 instead of zero, the divergences don’t happen, and on the real data we get fewer divergences.

Can anyone offer any insights here?

Just to copy the language of the paper here, the typical Bayesian model

x_i | \beta \mathop{\sim}_{\mathrm{i.i.d.}} p(x_i|\beta) \;\;\; \beta \sim p(\beta | \alpha)

is replaced with a hierarchical model

x_i | \beta_i \mathop{\sim} p(x_i | \beta_i) \;\;\; \beta_i \mathop{\sim}_{\mathrm{i.i.d.}} p(\beta_i | \alpha)

In effect this assigns one parameter per datapoint, which I would expect to generate lots and lots of divergences. Ideally you could choose p(\beta_i|\alpha) so that \beta_i can be marginalized analytically, i.e. so that

p(x_i|\alpha) = \int p(x_i|\beta_i)p(\beta_i|\alpha)d\beta_i

is closed-form. Given that section 3.1 is a brief explanation of conjugate exponential families, it seems the authors expect practitioners to either (i) judiciously choose priors to make the above integral have an analytic solution; or (ii) apply variational inference [sections 4+5].
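As an illustration of option (i), here is a minimal sketch of one textbook conjugate pairing. It is an assumption chosen for illustration, not the model from this thread: a zero-mean normal likelihood whose per-point variance v_i has an inverse-gamma prior with hypothetical parameters a and b. In that case the integral has a closed form (a Student-t density), and a brute-force numerical integration using only the standard library agrees with it.

```python
import math

def closed_form_marginal(x, a, b):
    """Student-t density obtained by integrating out the variance v analytically."""
    log_norm = math.lgamma(a + 0.5) - math.lgamma(a) - 0.5 * math.log(2.0 * math.pi * b)
    return math.exp(log_norm - (a + 0.5) * math.log1p(x * x / (2.0 * b)))

def numeric_marginal(x, a, b, lo=-8.0, hi=20.0, steps=4000):
    """Trapezoid rule over u = log(v) for the same integral."""
    def integrand(u):
        v = math.exp(u)
        # Normal(x | 0, v) likelihood
        normal = math.exp(-x * x / (2.0 * v)) / math.sqrt(2.0 * math.pi * v)
        # InverseGamma(v | a, b) prior density
        inv_gamma = math.exp(a * math.log(b) - math.lgamma(a) - (a + 1.0) * u - b / v)
        return normal * inv_gamma * v  # extra v is the Jacobian of dv = v du
    h = (hi - lo) / steps
    total = 0.5 * (integrand(lo) + integrand(hi))
    total += sum(integrand(lo + k * h) for k in range(1, steps))
    return total * h

for x in (0.0, 1.0, 3.0):
    print(x, closed_form_marginal(x, a=2.0, b=1.0), numeric_marginal(x, a=2.0, b=1.0))
```

When such a pairing is available, the sampler only sees \alpha (here a and b) rather than one latent parameter per data point.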


Based on my reading of the above, it appears that you have \beta_i parametrizing the scale \sigma of a truncated normal, and p(\beta_i | \alpha) as \mathrm{HalfCauchy}(\tau) for some fixed \tau. Given that you only have a single point from which to estimate each scale, the model would appear to be barely determined, and the HalfCauchy tail is not helping.
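To put a rough number on how heavy that tail is, the sketch below compares upper quantiles of a HalfCauchy and a HalfNormal prior with the same scale, using the closed-form HalfCauchy quantile \tau \tan(\pi p / 2). The unit scale and the function names are illustrative assumptions, not values from the original model.

```python
import math
from statistics import NormalDist

def half_cauchy_quantile(p, tau=1.0):
    """Inverse CDF of HalfCauchy(tau): the CDF is (2/pi) * atan(x / tau)."""
    return tau * math.tan(math.pi * p / 2.0)

def half_normal_quantile(p, sigma=1.0):
    """Inverse CDF of HalfNormal(sigma), via the standard normal quantile."""
    return sigma * NormalDist().inv_cdf((1.0 + p) / 2.0)

for p in (0.9, 0.99, 0.999):
    print(p, half_cauchy_quantile(p), half_normal_quantile(p))
```

The 0.99 quantile of the HalfCauchy is roughly 64 times its scale, against roughly 2.6 for the HalfNormal, so with one scale per data point a few very large prior draws are to be expected.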

I suspect that HalfNormal variables, particularly those realized near to 0, aren’t particularly informative as to the underlying \sigma, since P(0|\sigma) \propto \frac{1}{\sigma}. Even though \sigma=10 and \sigma=20 differ significantly, the probability of getting a point near 0 isn’t all that different (only O(2:1)); so most points won’t be particularly informative to the value of \sigma_i.
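The 2:1 figure can be checked directly from the HalfNormal density. This is a minimal sketch using only the standard library; the observation values are illustrative.

```python
import math

def half_normal_pdf(x, sigma):
    """Density of |Z| with Z ~ Normal(0, sigma), for x >= 0."""
    return math.sqrt(2.0 / math.pi) / sigma * math.exp(-x * x / (2.0 * sigma * sigma))

def likelihood_ratio(x, sigma_a=10.0, sigma_b=20.0):
    """How much more likely x is under sigma_a than under sigma_b."""
    return half_normal_pdf(x, sigma_a) / half_normal_pdf(x, sigma_b)

for x in (0.0, 0.5, 2.0):
    print(x, likelihood_ratio(x))
```

A single observation near zero therefore shifts the relative weight of \sigma=10 versus \sigma=20 by at most a factor of two.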

I suspect the most effective way to sample from a local version of this truncated normal model (which may not match the prior you’d most like to use!) would be to parametrize the \sigma_i on the log scale; and (since we are sampling anyway) to use a prior on \alpha. This would let you use a constraining light tail for sampling, while using the heavy tailed lognormal for the likelihood.

Notably, while I get divergences for HalfCauchy and HalfStudentT(3) distributions, this parameterization worked well for me:

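import numpy as np
import pymc3 as pm
import theano.tensor as tt

# alpha, N and dataset_realizations come from the data-generation snippet below
with pm.Model() as local_reparam:
    scale_log = np.log(alpha)  # fixed location of the scales on the log scale
    # per-point offsets on the log scale, with a light-tailed prior on their spread
    scale_offset_sigma = pm.HalfNormal('scale_offset_sigma', 1)
    scale_offset_log = pm.Normal('scale_offset_log', mu=0, sigma=scale_offset_sigma, shape=N)
    scale = pm.Deterministic('sigma_i', tt.exp(scale_log + scale_offset_log))
    # fit the first of the simulated datasets
    values = pm.HalfNormal('x_i', scale, observed=dataset_realizations['x_i'][0,:])
    local_repar_tr = pm.sample(800, tune=1500, chains=8, cores=2, nuts_kwargs=dict(target_accept=0.9))

pm.traceplot(local_repar_tr, ['scale_offset_sigma', 'sigma_i'],
              coords={'sigma_i_dim_0': range(5)});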

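And I just used pymc3 to generate the data:

import pymc3 as pm

alpha = 0.5   # scale of the HalfStudentT prior on the per-point sigmas
N = 80        # data points per dataset
N_data = 10   # number of simulated datasets

# generate data by sampling from the prior predictive
with pm.Model() as mod:
    scales = pm.HalfStudentT('sigma_i', nu=3, sigma=alpha, shape=(N,))
    values = pm.HalfNormal('x_i', scales, shape=(N,))

    dataset_realizations = pm.sample_prior_predictive(N_data)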