# Using `switch` with `MVNormal` gives unusual results

Expected behaviour:

In the univariate case (using `pm.Normal`), the `switch` statement works as expected:

```python
import numpy as np
import pymc as pm
import aesara.tensor as at

with pm.Model() as example:
    theta = pm.Normal("theta", 0, 0.1)
    constraint = at.lt(theta, 10)  # this should (almost) always be True
    x = pm.math.switch(
        constraint,
        at.as_tensor_variable(np.array([-np.inf, -np.inf], "float64")),
        np.array([0, 1]),
    )
    print(constraint.eval())
    print(x.eval())

    llh = pm.math.switch(
        constraint,
        at.as_tensor_variable(np.array(-np.inf, "float64")),
        pm.math.sum(pm.logp(pm.Normal.dist(mu=np.zeros(2), sigma=np.ones(2)), x))  # univariate case
        # pm.logp(pm.MvNormal.dist(mu=np.zeros(2), cov=np.diag(np.ones(2))), x)    # (*) multivariate case
    )
    pm.Potential("llh", llh)
```

Then `llh.eval()` correctly returns `-inf`.

In the multivariate case (uncomment the line marked (*)), the model compiles fine, but `llh.eval()` raises an error due to the `inf` values passed to `pm.logp`. But these should never be evaluated anyway, because of the switch statement.

`switch` always evaluates both branches, regardless of the value of `constraint`; see the "ifelse vs switch" section of the documentation for details.
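This eager evaluation is easy to reproduce with plain NumPy, since `switch` behaves much like `np.where`: both branch expressions are computed in full before the condition selects between them. A minimal sketch of the same failure mode (not the PyMC graph itself):

```python
import numpy as np

x = np.array([0.5, 3.0])

# Both branches are evaluated before selection, so log(x - 1.0)
# computes log(-0.5) -> nan for the first entry, even though the
# condition routes that entry to the safe branch.
with np.errstate(invalid="ignore"):
    out = np.where(x > 1.0, np.log(x - 1.0), 0.0)

print(out)  # the nan is discarded: [0.0, log(2)]
```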

You could try using `ifelse` in place of `switch`, but it might be better to just assign a large finite value like `-1e6` inside the switch to avoid the `inf`.
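A sketch of the finite-penalty idea, using `np.where` as a stand-in for `pm.math.switch` (the helper name here is hypothetical):

```python
import numpy as np

def penalized_logp(raw_logp, constraint_violated, penalty=-1e6):
    # Hypothetical helper: instead of -inf, return a large finite
    # penalty when the constraint is violated. The sampler will
    # still (effectively) always reject, but no infs propagate.
    return np.where(constraint_violated, penalty, raw_logp)

print(penalized_logp(-1.2, False))  # constraint satisfied -> -1.2
print(penalized_logp(-1.2, True))   # violated -> -1e6
```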


Ahh yes, I was aware that `switch` evaluates both branches, but didn’t think it was a problem since the univariate case was returning `inf`s without issue. I used `switch` because I need to apply the operation over a long list of values, and I believe `switch` is a lot faster since it vectorises the operation.

I realise that the problem actually arises because `pm.Normal.dist` accepts being evaluated at `inf`, whereas the multivariate case throws errors. Do you think it is worth allowing `pm.MvNormal.dist` to be evaluated at `inf`? It’s a sure-fire way of making the sampler reject.

I will use your suggestion of a large finite value for now.
Thanks!

I looked a bit more carefully at the problem, and I think you might be right, but I’m also a bit confused.

I’m confused because you set the value of `x`, the observation, to `-inf`, as well as the logp of the observation. It should suffice to set just the logp to `-inf` and leave the observation alone. For example, this code works fine, and will always reject values of `theta` below 10:

```python
llh = pm.math.switch(
    constraint,
    at.as_tensor_variable(np.array(-np.inf, "float64")),
    pm.logp(pm.MvNormal.dist(mu=np.zeros(2), cov=np.diag(np.ones(2))), at.full((2,), theta))
)
pm.Potential("llh", llh)
```

You might be right that the logp function for the `MvNormal` should map values of `-inf` to a logp of `-inf`. The only reason it doesn’t now is that there is no check for infinity in that function. There is a check to ensure that the covariance matrix is valid, but none that the data (or the mean vector) are finite.
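A sketch of what such a guard could look like, written as a standalone NumPy function rather than PyMC's actual implementation (all names here are hypothetical):

```python
import numpy as np

def mvnormal_logp_guarded(value, mu, cov):
    # Hypothetical guard: map non-finite observations to a logp of
    # -inf instead of raising, mirroring the univariate behaviour.
    value = np.asarray(value, dtype=float)
    mu = np.asarray(mu, dtype=float)
    if not np.all(np.isfinite(value)):
        return -np.inf
    # Standard multivariate normal log-density (cov assumed positive definite)
    diff = value - mu
    _, logdet = np.linalg.slogdet(cov)
    quad = diff @ np.linalg.solve(cov, diff)
    return -0.5 * (value.size * np.log(2 * np.pi) + logdet + quad)

print(mvnormal_logp_guarded([-np.inf, -np.inf], np.zeros(2), np.eye(2)))  # -inf
print(mvnormal_logp_guarded([0.0, 0.0], np.zeros(2), np.eye(2)))          # -log(2*pi)
```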

Personally I think it’s probably OK that the program fails loudly when you try to evaluate the logp of infinity, but it’s odd that the univariate distribution doesn’t.


Hi Jesse,

Yes, I am doubling up on the constraint there. I wasn’t sure which would result in the most efficient inference, but since the `llh` will have a step gradient at the constraint, that should suffice. In my actual case I have two separate constraints: one which affects the data values `x` (and through them the `llh`), and one which directly affects the `llh` (due to parameter constraints). I suppose I could combine them all into the `llh` via:

```python
llh = pm.math.switch(
    at.gt(constraint_data + constraint_params, 0),
    at.as_tensor_variable(np.array(-np.inf, "float64")),
    pm.logp(pm.MvNormal.dist(mu=np.zeros(2), cov=np.diag(np.ones(2))), x)
)
```

This way I don’t have to tamper with the logp observations.


I don’t know why a model constraint would ever alter the observed data (it’s already been observed after all), so I would always have `x` enter the likelihood as-is.

Another alternative for handling the two separate constraints is to use `at.or_` and/or `at.and_` to check the two constraints separately in your switch. There might be bugs that arise from adding them together if they mutually offset in just the right way (but also maybe that’s impossible and I just don’t know it).
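For 0/1 boolean indicators the two forms do in fact agree, since neither term can be negative, so summing cannot offset; still, a logical OR states the intent directly. A quick NumPy check of the equivalence:

```python
import numpy as np

c_data = np.array([True, False, True, False])
c_params = np.array([False, False, True, True])

# Logical combination is explicit about intent:
combined = np.logical_or(c_data, c_params)

# Summing relies on True == 1; for 0/1 indicators the
# "sum > 0" test agrees with the logical OR.
summed = (c_data.astype(int) + c_params.astype(int)) > 0

print(np.array_equal(combined, summed))  # True
```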
