PyMC3 to PyMC: Alternative to check_parameters and bound?

Hello PyMCers!

TL;DR: How should I enforce logic within the logp of a custom distribution without check_parameters or distributions.dist_math.bound?

I’m an ecologist who’s new to PyMC. Ultimately, I would like to fit open capture-recapture models, specifically, Jolly-Seber models, to simulated data. To that end, I’ve been following this helpful notebook. It was written in PyMC3, and it’s been easy to port to PyMC for the most part.

However, I’m having trouble with the portion of the notebook where the author implements the IncompleteMultinomial distribution (see under the heading The Jolly-Seber Model). The distribution relates to the number of unmarked animals captured at each occasion:
$\{u_1,\dots,u_T\} \sim \text{Mult}(N;\psi_1p,\dots,\psi_Tp)$, where $\psi_1=\beta_0$, $\psi_{i+1}=\psi_i(1-p)\phi_i + \beta_i$, and $i$ indicates the occasion. Altogether, this portion of the likelihood looks like: ${N \choose N-u}\left(1 - \sum^{T}_{i=1}\psi_ip\right)^{N-u}\prod^{T}_{i=1}(\psi_ip)^{u_i}$, where $u=\sum u_i$.
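For concreteness, the recursion for the cell probabilities can be sketched in plain Python. This is only an illustration: the function name cell_probs and the example values for beta, phi, and p are made up, and p is assumed constant across occasions, as in the formula above.

```python
def cell_probs(beta, phi, p):
    """Return [psi_1 * p, ..., psi_T * p] for the recursion
    psi_1 = beta_0, psi_{i+1} = psi_i * (1 - p) * phi_i + beta_i.

    beta holds beta_0..beta_{T-1}; phi holds phi_1..phi_{T-1}.
    """
    psi = [beta[0]]
    for i in range(1, len(beta)):
        # psi[-1] is psi_i, phi[i - 1] is phi_i, beta[i] is beta_i
        psi.append(psi[-1] * (1 - p) * phi[i - 1] + beta[i])
    return [s * p for s in psi]


# Hypothetical values for T = 3 occasions
cells = cell_probs(beta=[0.5, 0.3, 0.2], phi=[0.8, 0.8], p=0.5)
never_captured = 1 - sum(cells)  # probability mass of the unobserved cell
```

The leftover mass, 1 minus the sum of the cell probabilities, corresponds to the animals that are never captured, which is why the multinomial is "incomplete".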

As such, the notebook author defines this portion of the likelihood as a custom discrete distribution, IncompleteMultinomial(pm.Discrete). To do so, he uses pymc3.distributions.dist_math.bound in the logp of the distribution, which enforces constraints such as pt.all(x >= 0), pt.all(x <= n), and pt.sum(x) <= n. However, this function is not available in pymc.distributions.dist_math. There is a check_parameters function, but its docstring states that it “should not be used to enforce the logic of the logp expression under the normal parameter support.”

My question: what is the PyMC substitute for bound, in this case, if not check_parameters? How would a PyMC expert rewrite this part of the model?

Thank you for your patience! I’m dying to make the switch from R/JAGS (I’ve never liked the ecosystem) to PyMC.
Phil

Usually we combine a switch with check_parameters. Our halfnormal logp looks something like:

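import numpy as np
import pymc as pm
from pymc.distributions.dist_math import check_parameters

def logp(value, sigma):
  res = ...
  # value outside the support: the logp is -inf
  res = pm.math.switch(
    value < 0,
    -np.inf,
    res,
  )
  # flag invalid parameter values (sigma must be positive)
  res = check_parameters(res, sigma > 0, msg="sigma > 0")
  return res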
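The division of labor between the two checks can be illustrated with a plain-Python analogy. This is not PyMC code: halfnormal_logp is a hypothetical scalar function, the ValueError only stands in for the role check_parameters plays, and the early return stands in for the switch.

```python
import math


def halfnormal_logp(value, sigma):
    """Scalar half-normal log-density, written to mirror the two checks."""
    # Parameter check: an invalid sigma is a modelling error, so raise.
    if sigma <= 0:
        raise ValueError("sigma > 0")
    # Support check: a negative value is simply impossible, so logp = -inf.
    if value < 0:
        return -math.inf
    return 0.5 * math.log(2 / math.pi) - math.log(sigma) - value**2 / (2 * sigma**2)
```

In other words, conditions on the value belong in the switch, while conditions on the parameters belong in check_parameters.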

Thank you. That worked perfectly. For posterity, this is my version of the logp from the notebook.

import numpy as np
import pymc as pm
import pytensor.tensor as pt
from pymc.distributions.dist_math import factln

def logp(x, n, p):

    # number of individuals never captured
    x_last = n - x.sum()

    # calculate the logp for the observations
    res = factln(n) + pt.sum(x * pt.log(p) - factln(x)) \
            + x_last * pt.log(1 - p.sum()) - factln(x_last)

    # return -inf when x falls outside the support
    good_conditions = pt.all(x >= 0) & pt.all(x <= n) & (pt.sum(x) <= n)
    res = pm.math.switch(good_conditions, res, -np.inf)

    return res
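One way to sanity-check a logp like this is to compare it against a plain-Python reference on small inputs. The sketch below is an assumption-laden helper, not part of the notebook: reference_logp is a hypothetical name, it expects x and p as ordinary Python lists, and it assumes every p_i > 0 and sum(p) < 1.

```python
import math
from itertools import product


def reference_logp(x, n, p):
    """Plain-Python incomplete-multinomial log-probability."""
    x_last = n - sum(x)
    # Outside the support: mirror the switch branch and return -inf.
    if any(xi < 0 or xi > n for xi in x) or x_last < 0:
        return -math.inf

    def log_fact(k):
        return math.lgamma(k + 1)

    return (
        log_fact(n)
        + sum(xi * math.log(pi) - log_fact(xi) for xi, pi in zip(x, p))
        + x_last * math.log(1 - sum(p))
        - log_fact(x_last)
    )


# Summing exp(logp) over every candidate x should give (approximately) 1;
# points outside the support contribute exp(-inf) = 0.
n, p = 4, [0.25, 0.25, 0.2]
total = sum(
    math.exp(reference_logp(x, n, p))
    for x in product(range(n + 1), repeat=len(p))
)
```

If the PyTensor version is evaluated on the same x, n, and p, the two values should agree up to floating-point error.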