TruncatedNormal logp returning all -inf?

I am trying to build a mixture model that contains TruncatedNormal components, but I was having trouble optimizing and sampling. I evaluated the model logp at a few test values and traced some oddities back to the TruncatedNormal: its density evaluates to zero (a logp of -inf) at every input value I pass in. For example:

import numpy as np
import pymc as pm

with pm.Model() as model:
    mu = pm.Uniform("mu", -5, 2)
    ln_std = pm.Uniform("ln_std", -8, 1)
    std = pm.Deterministic("std", pm.math.exp(ln_std))

    # Normal truncated below at 0, compared with the same untruncated Normal
    dist1 = pm.TruncatedNormal.dist(mu=mu, sigma=std, lower=0, upper=None)
    dist2 = pm.Normal.dist(mu=mu, sigma=std)
    x_grid = np.linspace(-5, 10, 1024)
    dist1_logp = pm.Deterministic("dist1", pm.logp(dist1, x_grid))
    dist2_logp = pm.Deterministic("dist2", pm.logp(dist2, x_grid))

# Compile functions of (mu, std) that evaluate each logp on the grid
func1 = model.compile_fn(dist1_logp, inputs=[mu, std])
func2 = model.compile_fn(dist2_logp, inputs=[mu, std])

init_p = {"mu": 1, "std": 0.1}
print(func1(init_p))
print(func2(init_p))

[-inf -inf -inf ... -inf -inf -inf]
[-1798.61635344 -1789.8294493  -1781.06404481 ... -4022.26639085
 -4035.43062232 -4048.61635344]

I think I’m misunderstanding something here - any ideas? Or is this a bug?

Thanks for reading!

It seems that the TruncatedNormal logp returns -inf for all values if any one of them is out of bounds:

import pymc as pm

with pm.Model() as m:
    x = pm.TruncatedNormal("x", mu=0, sigma=1, lower=0, upper=None, transform=None, shape=2)
# -1 is invalid, but 1 is fine
print(m.compile_logp(sum=False)({"x": [-1, 1]}))  # [array([-inf, -inf])]

It should return -inf for the first value and a finite value for the second.
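For reference, the expected per-element output of the example above can be computed directly. The following is a minimal NumPy sketch, assuming the standard density of a Normal truncated below at `lower` with no upper bound, log N(x | mu, sigma) - log(1 - Phi((lower - mu) / sigma)), and an inclusive lower bound; `truncnorm_lower_logp` is a hypothetical helper, not part of PyMC.

```python
import math

import numpy as np


def truncnorm_lower_logp(value, mu, sigma, lower):
    """Hypothetical helper: elementwise log-density of Normal(mu, sigma)
    truncated below at `lower` (no upper bound). Each out-of-bounds value
    gets -inf on its own; valid values are unaffected."""
    value = np.asarray(value, dtype=float)
    z = (value - mu) / sigma
    # Untruncated Normal log-density
    normal_logp = -0.5 * z**2 - math.log(sigma) - 0.5 * math.log(2 * math.pi)
    # Log of the probability mass above the bound: log(1 - Phi(alpha))
    alpha = (lower - mu) / sigma
    log_mass = math.log(0.5 * math.erfc(alpha / math.sqrt(2)))
    logp = normal_logp - log_mass
    # Mask each value separately (assumes the bound is inclusive)
    return np.where(value >= lower, logp, -np.inf)


print(truncnorm_lower_logp([-1.0, 1.0], mu=0.0, sigma=1.0, lower=0.0))
```

With mu=0, sigma=1 and lower=0, half of the Normal mass lies above the bound, so the finite entry is the standard Normal log-density at 1 plus log 2, roughly -0.7258, while the first entry is -inf.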

Ah yes, this looks like it. In my example, if I change to x_grid = np.linspace(1, 10, 1024) (i.e. so that all values are above the lower bound), the values are finite as expected. OK, I think this is a bug, so I will open an issue. Thanks for helping debug this!


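The fix is to use a switch on the value instead of passing the bounds into check_parameters, as we currently do here: pymc/ at cadff75cc3787b3d98e8a77f659c1cbb42de63bd · pymc-devs/pymc · GitHub. check_parameters collapses all of its conditions into a single scalar check, so one out-of-bounds value turns the entire logp into -inf, whereas a switch applies the bound to each value separately.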

Usually, we do it like this:
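The difference between the two approaches can be sketched in NumPy, with `np.where` standing in for PyTensor's `switch` so the example runs without PyMC. It assumes, as the symptom suggests, that check_parameters reduces all of its conditions to one scalar; `reduced_check` and `elementwise_switch` are hypothetical names used only for illustration.

```python
import numpy as np


def reduced_check(logp, value, lower):
    # Assumed check_parameters-style behaviour: every condition is reduced
    # to ONE scalar, so a single bad value sends every entry to -inf.
    all_ok = np.all(value >= lower)
    return np.where(all_ok, logp, -np.inf)


def elementwise_switch(logp, value, lower):
    # Per-element condition, analogous to switch(value >= lower, logp, -inf):
    # only the out-of-bounds entries become -inf.
    return np.where(value >= lower, logp, -np.inf)


value = np.array([-1.0, 1.0])
# Stand-in log-densities (untruncated standard Normal); only the masking matters
logp = -0.5 * value**2 - 0.5 * np.log(2 * np.pi)

print(reduced_check(logp, value, lower=0.0))       # both entries -inf
print(elementwise_switch(logp, value, lower=0.0))  # only the first entry is -inf
```

The scalar reduction is a reasonable choice for parameter checks (e.g. sigma > 0), where one invalid parameter really does invalidate the whole density, but value bounds need the per-element form.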


That makes sense - I was looking at the source for TruncatedNormal and was confused about how it was even enforcing the bounds. I’ll have a go at a PR!


For completeness, see: Fix bug in which TruncatedNormal returns -inf for all values if any value is out of bounds by adrn · Pull Request #6128 · pymc-devs/pymc · GitHub
