Truncated Normal

Hi!
I am having trouble with TruncatedNormal distributions. My understanding is that these are just normal distributions restricted to a specified lower and/or upper bound.
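For reference, the standard definition of that restriction renormalizes the normal density over the allowed interval [a, b], where φ and Φ are the standard normal pdf and cdf. With only a lower bound, b = ∞ and the Φ((b - μ)/σ) term is 1:

$$
f(x \mid \mu, \sigma, a, b) = \frac{\frac{1}{\sigma}\,\phi\!\left(\frac{x-\mu}{\sigma}\right)}{\Phi\!\left(\frac{b-\mu}{\sigma}\right) - \Phi\!\left(\frac{a-\mu}{\sigma}\right)}, \qquad a \le x \le b
$$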

However, I get divergences when using them, even in the simplest case (code below). Can anyone help me explain this?

import pymc3 as pm
import matplotlib.pyplot as plt

with pm.Model() as model:
    # Normal(0, 5) restricted to [0, inf); no data, so this just samples the prior
    N = pm.TruncatedNormal('N', mu=0, sd=5, lower=0)
    trace = pm.sample()

pm.traceplot(trace)
plt.savefig("trunc_test")
plt.close()

I believe it is because you are giving it a Normal distribution with mean 0 and cutting away all the mass below zero, which makes it hard for the sampler.

Try the following:

import numpy as np
import pymc3 as pm

y = np.random.normal(0.1, 1, size=100)

with pm.Model() as model:
    N = pm.TruncatedNormal('N', mu=0, sd=5, lower=0)
    y_pred = pm.Normal('y_pred', mu=N, sigma=1, observed=y)
    trace = pm.sample()

pm.traceplot(trace)

We still get divergences. Notice that the mean of the data (0.1) is very close to the bound. If we move a little further away from it:

y = np.random.normal(0.5, 1, size=100)

with pm.Model() as model:
    N = pm.TruncatedNormal('N', mu=0, sd=5, lower=0)
    y_pred = pm.Normal('y_pred', mu=N, sigma=1, observed=y)
    trace = pm.sample()

pm.traceplot(trace)

No divergences.
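The "mean close to the bound" point can be quantified without sampling. This is a minimal sketch in NumPy/SciPy, not PyMC3, and it assumes for simplicity that the sample mean of y equals the value used to generate it (0.1 or 0.5). A Normal prior times a Normal likelihood with known sd 1 gives a Normal posterior by conjugacy, and the prior's lower bound simply truncates it. `posterior` is a hypothetical helper; the sketch reports how much posterior mass for N lies within 0.05 of the bound in each case.

```python
import numpy as np
from scipy import stats

PRIOR_MU, PRIOR_SIGMA, LOWER = 0.0, 5.0, 0.0  # TruncatedNormal('N', mu=0, sd=5, lower=0)
NOISE_SIGMA, N_OBS = 1.0, 100                 # Normal('y_pred', N, 1) with 100 observations


def posterior(y_mean):
    """Hypothetical helper: posterior of N given the sample mean (conjugate Normal update)."""
    precision = 1 / PRIOR_SIGMA**2 + N_OBS / NOISE_SIGMA**2
    mean = (PRIOR_MU / PRIOR_SIGMA**2 + N_OBS * y_mean / NOISE_SIGMA**2) / precision
    sd = precision ** -0.5
    # The prior's lower bound carries over to the posterior as a truncation.
    # scipy's truncnorm takes the bounds in sd units relative to loc.
    return stats.truncnorm((LOWER - mean) / sd, np.inf, loc=mean, scale=sd)


mass_near_bound = {}
for y_mean in (0.1, 0.5):
    post = posterior(y_mean)
    # Posterior probability that N lies in [LOWER, LOWER + 0.05]
    mass_near_bound[y_mean] = post.cdf(LOWER + 0.05)
    print(f"y mean {y_mean}: P(N < {LOWER + 0.05}) = {mass_near_bound[y_mean]:.2e}")
```

Under these assumptions the 0.1 case puts a substantial share of the posterior right next to the bound, while the 0.5 case puts essentially none there, which matches the divergence pattern above.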

Hope it helps!

Hi @luisroque,
Thank you for getting back to me!
I am having trouble accepting that explanation, because restricting the TruncatedNormal even further still works, as long as I also specify an upper bound.

As an example, this gives no divergences:
N = pm.TruncatedNormal('N', mu=0, sd=5, lower=3, upper=8)
while this does:
N = pm.TruncatedNormal('N', mu=0, sd=5, lower=3)

I should add that I see the same behaviour with variables bounded via pm.Bound.
Could this be a bug?

Hmm, you are right, and I can't work out why. It would be good if someone with more experience could help. Here are some notes from some digging I did.

From what I see in the code, PyMC3 transforms bounded variables such as a TruncatedNormal, so the sampler works on an unconstrained scale. I can also see that the logp is renormalized, i.e. divided by the probability mass that lies inside the bounds, which raises the density over the region you are considering. When you use both lower and upper bounds, mass is cut from both sides, so the density is raised even more. I would imagine that this could help the sampler explore the space.
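To make the transform point concrete, here is a minimal sketch using NumPy/SciPy rather than PyMC3. It assumes (my reading of PyMC3's default transforms, not checked against every version) that a lower bound alone is mapped as x = lower + exp(y), and a lower plus upper bound as x = lower + (upper - lower) * sigmoid(y). It builds the renormalized TruncatedNormal(mu=0, sigma=5) log density on the unconstrained y scale, adds each transform's log-Jacobian, and checks that both still integrate to 1. `logp_lower_only`, `logp_interval` and `total_mass` are hypothetical helpers for illustration.

```python
import numpy as np
from scipy import stats
from scipy.integrate import quad
from scipy.special import expit

MU, SIGMA = 0.0, 5.0
LOWER, UPPER = 3.0, 8.0  # bounds from the examples in this thread


def logp_lower_only(y):
    """Unconstrained log density for lower=LOWER only, assuming x = LOWER + exp(y)."""
    x = LOWER + np.exp(y)
    # Renormalize: divide the normal density by the mass kept above LOWER.
    log_norm = stats.norm.logsf(LOWER, MU, SIGMA)
    log_jac = y  # log|dx/dy| = log(exp(y)) = y
    return stats.norm.logpdf(x, MU, SIGMA) - log_norm + log_jac


def logp_interval(y):
    """Unconstrained log density for [LOWER, UPPER], assuming x = LOWER + (UPPER - LOWER) * sigmoid(y)."""
    x = LOWER + (UPPER - LOWER) * expit(y)
    # Renormalize: divide by the mass kept between LOWER and UPPER.
    log_norm = np.log(stats.norm.cdf(UPPER, MU, SIGMA) - stats.norm.cdf(LOWER, MU, SIGMA))
    # log|dx/dy| = log(UPPER - LOWER) + log(sigmoid(y)) + log(1 - sigmoid(y))
    log_jac = np.log(UPPER - LOWER) - np.logaddexp(0.0, -y) - np.logaddexp(0.0, y)
    return stats.norm.logpdf(x, MU, SIGMA) - log_norm + log_jac


def total_mass(logp):
    # Both densities are negligible outside [-50, 50] on the unconstrained scale;
    # the breakpoints help quad find the region where the mass sits.
    value, _ = quad(lambda y: np.exp(logp(y)), -50.0, 50.0,
                    points=[-5.0, 0.0, 5.0], limit=200)
    return value


mass_lower = total_mass(logp_lower_only)
mass_interval = total_mass(logp_interval)
print(mass_lower, mass_interval)
```

Plotting the two `logp_*` functions over y is a quick way to see how differently shaped the two targets are for NUTS (the log transform gives a long left tail and a very steep right side); whether that shape difference fully explains the divergences is still an open question here.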

If we increase target_accept with only the lower bound, we can get rid of the divergences, but the effective sample size is quite low (it is low in any case).

With a very high upper bound the behaviour is as expected, i.e. the same as when not defining one. The strange part is the range in between, e.g. a sigma of 10 and an upper bound of 100: I don't see why we then get no divergences and a significantly larger ESS.

Yes, thank you,
it would be great to figure this out!

Does anyone have any suggestions on how to understand this? :smiley:

Hi!
I still have not figured this out, and I believe there is a simple solution to the problem. Still hoping for a suggestion on how to explain this!

Thanks

I also get divergences when sampling from a HalfNormal:

import pymc3 as pm

with pm.Model() as m:
    x = pm.HalfNormal('x', sigma=1)
    trace = pm.sample()
    # Divergences

I think NUTS is struggling because the peak of the distribution sits right at the boundary of the domain. The problems seem to go away if the mean is far enough from the boundary:

with pm.Model() as m:
    x = pm.TruncatedNormal('x', lower=0, mu=3, sigma=1)
    trace = pm.sample()
    # No divergences
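A quick way to check the "peak at the boundary" point, sketched with SciPy's `truncnorm` instead of PyMC3 (`truncated_normal` is a hypothetical helper): TruncatedNormal(mu=0, sigma=5, lower=0) is the same distribution as a HalfNormal with scale 5, its density peaks at the bound, and it puts far more mass right next to the bound than TruncatedNormal(mu=3, sigma=1, lower=0).

```python
import numpy as np
from scipy import stats


def truncated_normal(mu, sigma, lower, upper=np.inf):
    # scipy's truncnorm takes the bounds in sd units relative to mu
    a, b = (lower - mu) / sigma, (upper - mu) / sigma
    return stats.truncnorm(a, b, loc=mu, scale=sigma)


at_boundary = truncated_normal(mu=0, sigma=5, lower=0)  # peak at the bound
away = truncated_normal(mu=3, sigma=1, lower=0)         # peak well inside the domain
half = stats.halfnorm(scale=5)

grid = np.linspace(0, 20, 2001)
mode_boundary = grid[np.argmax(at_boundary.pdf(grid))]
mode_away = grid[np.argmax(away.pdf(grid))]

# Probability of landing within 0.5 of the bound
near_boundary = at_boundary.cdf(0.5)
near_away = away.cdf(0.5)

# A normal with mean 0 truncated at 0 is exactly a half-normal
same_as_half = np.allclose(at_boundary.pdf(grid), half.pdf(grid))
print(mode_boundary, mode_away, near_boundary, near_away, same_as_half)
```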

If you want NUTS to sample this anyway, you can give it some extra legroom by increasing tune and target_accept:

with pm.Model() as m:
    x = pm.TruncatedNormal('x', lower=0, mu=0, sd=5)
    trace = pm.sample(tune=2000, target_accept=.95)
    # No divergences

If anyone else can chime in and confirm this is reasonable behavior from NUTS, that would be really helpful.

This might be related: "Zero-excluding priors are probably a bad idea for hierarchical variance parameters" (Statistical Modeling, Causal Inference, and Social Science blog)
