How to improve this sort of model so that it does not get stuck/diverge?


This is a very open-ended question, but it is a problem I commonly run into when fitting models with discrete parameters, particularly when the observed data rule out a significant portion of the prior as impossible (e.g., if I observe 5 cars, then it is impossible that there are 4 or fewer cars in total). Hopefully your advice on how to tackle this sort of problem may also be useful for others.

In particular, I am trying to adapt the model described in Rasmus Bååth's blog. The idea is that you draw n socks and observe k pairs. What reasonable inferences can you make about the total number of socks? This is slightly complicated by the fact that some socks may not have a matching partner.

I copied the prior specification without changes:

import pymc3 as pm

with pm.Model() as m:
    prior_mu = 30
    prior_sd = 15
    prior_size = -prior_mu**2 / (prior_mu - prior_sd ** 2)
    n_socks = pm.NegativeBinomial('n_socks', prior_mu, prior_size)

    prop_pairs = pm.Beta('prop_pairs', 15, 2)

    n_pairs = pm.Deterministic('n_pairs', pm.math.floor(n_socks // 2 * prop_pairs))
    n_odd = pm.Deterministic('n_odd', n_socks - (n_pairs * 2))
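As a sanity check on the prior above: the `prior_size` formula is a moment-matching step. For a negative binomial with mean mu and dispersion alpha, the variance is mu + mu²/alpha (this is the parameterization I understand `pm.NegativeBinomial(mu, alpha)` to use). The plain-Python sketch below, with helper names made up for illustration, checks that the chosen size gives back the intended standard deviation:

```python
import math

def negbin_size(mu, sd):
    """Dispersion (size) for which a negative binomial with mean mu has the given sd."""
    # Same expression as prior_size in the model above
    return -mu**2 / (mu - sd**2)

def negbin_sd(mu, size):
    """Standard deviation of a negative binomial, assuming var = mu + mu^2 / size."""
    return math.sqrt(mu + mu**2 / size)

prior_mu, prior_sd = 30, 15
prior_size = negbin_size(prior_mu, prior_sd)

print(prior_size)                       # equals 900 / 195
print(negbin_sd(prior_mu, prior_size))  # should recover prior_sd
```

Note that the formula only makes sense when `prior_sd ** 2 > prior_mu`; otherwise the size comes out negative.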

The likelihood function was tricky but with some help from StackExchange I was able to get it:

import numpy as np
import theano
import theano.tensor as tt
from pymc3.distributions.dist_math import factln

def th_binomln(n, k):
    # Log of the binomial coefficient "n choose k"
    return factln(n) - (factln(k) + factln(n-k))

def th_binom(n, k):
    # Binomial coefficient, i.e. the exponential of its log
    return tt.exp(th_binomln(n, k))

def th_log_prob_pairs(m, l, n, k):
    # Probability of observing k pairs
    # Given n draws, m pairs and l singletons (missing pairs)
    ntotal = th_binomln(2*m + l, n)
    npairs = th_binomln(m, k)

    # Loop through all the possible values of drawn singletons
    # to calculate the probability of the observed missing pairs
    nsingles, _ = theano.scan(
        lambda j, m, n, k: th_binom(m - k, n - 2*k - j) * 2 ** (n - 2*k - j) * th_binom(l, j),
        sequences = [tt.arange(0, tt.min([l+1, n -2*k +1]))],
        non_sequences=[m, n, k],
    )

    nsingles = tt.log(tt.sum(nsingles))
    result = npairs + nsingles - ntotal
    # Return result or -inf for impossible combinations
    # (m < k; 2*m+l <n; n/2 < k)
    return tt.switch(m >= k, tt.switch(2*m + l >= n, tt.switch(n/2 >= k, result, -np.inf), -np.inf), -np.inf)
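To convince myself the formula is right independently of Theano, here is a plain-Python sketch of the same probability using exact integer arithmetic. The function name `prob_pairs` is made up for this check, and `math.comb` needs Python 3.8 or later. The probabilities over all possible values of k should sum to one:

```python
from math import comb

def prob_pairs(m, l, n, k):
    """P(k complete pairs | n draws from m pairs and l singletons)."""
    # Impossible combinations, mirroring the nested tt.switch above
    if k < 0 or m < k or 2 * m + l < n or n < 2 * k:
        return 0.0
    total = comb(2 * m + l, n)  # all ways to draw n socks
    ways = 0
    # j = number of drawn socks that come from the l singletons
    for j in range(min(l, n - 2 * k) + 1):
        unmatched = n - 2 * k - j  # drawn socks whose partner was not drawn
        ways += comb(m - k, unmatched) * 2 ** unmatched * comb(l, j)
    return comb(m, k) * ways / total

# Small example: 3 pairs, 2 singletons, 4 socks drawn
probs = [prob_pairs(3, 2, 4, k) for k in range(3)]
print(probs, sum(probs))
```

Exponentiating the output of `th_log_prob_pairs` for a few fixed inputs and comparing against this function is a cheap way to catch mistakes in the `scan` bounds.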

I then add a potential term for the likelihood:

with m:
    pm.Potential('llik', tt.sum([th_log_prob_pairs(n_pairs, n_odd, 11, obs) for obs in data]))

This model produces estimates similar to the ones described in the blog, where the data is simply [0] (i.e., zero pairs were observed from a single draw of 11 socks). However, it shows a few divergences when the sampler is around small values of n_socks.

With more data (still drawing 11 socks per observation), these divergences disappear.
Btw, this is how I generate fake data:

def simulate_data(m, l, n):
    assert n <= m*2 + l

    socks = list(range(m)) * 2 + list(range(-l, 0))
    picked_socks = np.random.choice(socks, n, replace=False)

    obs_pairs = picked_socks.size - np.unique(picked_socks).size
    return obs_pairs

However, if I increase the number of socks drawn (e.g., to 21), I start having problems again. For one, I have to manually specify a higher testval for n_socks to avoid a bad initial energy. I tried something like this:

# Set testval to minimum plausible (or else pymc gets stuck)
# Assuming no missing pairs
min_obs = min(data) * 2
min_plausible_value = (n_draws - min_obs)*2 + min_obs + 1
n_socks = pm.NegativeBinomial('n_socks', prior_mu, prior_size, testval=min_plausible_value)

This way the sampler is able to start. However, I still get a lot of divergences. I think this has to do with the sampler trying to jump to areas of the prior that are impossible given the data.
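To get a rough feel for how much of the prior is excluded, the sketch below sums the prior mass on n_socks values smaller than the number of socks drawn, which are impossible regardless of how many pairs were observed. It uses only the standard library; the pmf is written out by hand assuming the mean/dispersion form of the negative binomial, and the function names are made up for illustration. The observed pairs rule out further values, so this is only a lower bound on the excluded mass:

```python
import math

def negbin_pmf(x, mu, alpha):
    """Negative binomial pmf in the mean (mu) / dispersion (alpha) form."""
    log_p = (
        math.lgamma(x + alpha) - math.lgamma(alpha) - math.lgamma(x + 1)
        + alpha * math.log(alpha / (alpha + mu))
        + x * math.log(mu / (alpha + mu))
    )
    return math.exp(log_p)

def prior_mass_below(n_draws, mu=30, sd=15):
    """Prior probability that n_socks < n_draws (impossible given the draw)."""
    alpha = -mu**2 / (mu - sd**2)  # same moment matching as prior_size
    return sum(negbin_pmf(x, mu, alpha) for x in range(n_draws))

for n_draws in (11, 21):
    print(n_draws, round(prior_mass_below(n_draws), 3))
```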

The only solution I can think of is to narrow the prior down to the values that are possible, but this makes me a bit sad. I feel it goes against the philosophy that your prior should be your best guess for the true parameters before you observe any data at all (including how much data you are going to collect).

Do you have any suggestions for how I could improve this sort of model? Are the issues I am finding unavoidable?

Thanks for your attention and help!