Dirichlet Distributions and Batching in PyMC3 3.9+

Hi Everyone,

tl;dr
Drawing from a Dirichlet distribution with shape (1,N):

with pm.Model() as model:
    α = pm.Dirichlet(r'α', a=5*np.ones((1, N)), shape=(1, N))

appears to work in PyMC3 < 3.9 but fails in 3.9 and later with the error:

Bad initial energy, check any log probabilities that are inf or -inf, nan or very small:
α_stickbreaking__   -inf
CRITICAL:pymc3:Bad initial energy, check any log probabilities that are inf or -inf, nan or very small:
α_stickbreaking__   -inf
pymc3.parallel_sampling.RemoteTraceback: 
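The variable named in the error, `α_stickbreaking__`, is the unconstrained version of α that the sampler actually works with. A Dirichlet value lives on the simplex, so it is mapped to an unconstrained vector with one fewer component by a stick-breaking transform. The NumPy sketch below uses the Stan-style stick-breaking construction, applied along the last axis only. It is an illustration, not PyMC3's implementation, which may differ in detail, and the function names are hypothetical. It shows that a (1,N) value should map to a (1,N-1) unconstrained value while the leading length-1 axis is simply carried along. Mishandling that extra axis in the transform or its Jacobian is one plausible (unconfirmed) way for a shape-dependent version change to produce a -inf logp.

```python
import numpy as np


def stick_breaking_forward(y):
    """Hypothetical helper: map unconstrained y (..., K-1) to a simplex point x (..., K)."""
    y = np.asarray(y, dtype=float)
    km1 = y.shape[-1]
    K = km1 + 1
    # Offsets log(K-1), ..., log(1) centre the map so that y = 0 gives the uniform point.
    offsets = np.log(K - 1 - np.arange(km1))
    # Fraction of the remaining stick broken off at each step.
    z = 1.0 / (1.0 + np.exp(-(y - offsets)))
    x = np.empty(y.shape[:-1] + (K,))
    remaining = np.ones(y.shape[:-1])
    for k in range(km1):
        x[..., k] = remaining * z[..., k]
        remaining = remaining - x[..., k]
    x[..., -1] = remaining  # the last component takes whatever is left
    return x


def stick_breaking_inverse(x):
    """Hypothetical helper: map a simplex point x (..., K) back to unconstrained y (..., K-1)."""
    x = np.asarray(x, dtype=float)
    K = x.shape[-1]
    offsets = np.log(K - 1 - np.arange(K - 1))
    # Stick length left before each break: 1, 1 - x_0, 1 - x_0 - x_1, ...
    used = np.concatenate(
        [np.zeros(x.shape[:-1] + (1,)), np.cumsum(x[..., :-1], axis=-1)[..., :-1]], axis=-1
    )
    z = x[..., :-1] / (1.0 - used)
    return np.log(z) - np.log1p(-z) + offsets  # logit(z) + offsets


# A batched (1, K-1) unconstrained value maps to a (1, K) simplex point.
y = np.array([[0.3, -1.2, 0.5]])
x = stick_breaking_forward(y)
print(x.shape)                                   # (1, 4)
print(np.allclose(stick_breaking_inverse(x), y))  # True
```

Each row of `x` sums to 1 up to rounding, and the transform only ever acts on the last axis, so batch dimensions in front of it should not change the result.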

Context

I’m looking at data that appear to be Poisson-distributed with an observed rate Λ(t) (T observations over time) and that have been sharded into N partitions. I’d like to model the relative share of the overall Poisson count that each partition accounts for.

To do that, I’m running N Poisson regressions (one per partition), each with rate Λ*α, where α is the share of the full rate ‘owned’ by that partition. I draw the α’s from a Dirichlet distribution.
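The per-partition rate Λ*α follows from the standard Poisson splitting (thinning) property. Assuming each event in time step t independently lands in partition i with probability α_i, the partition counts are independent Poissons whose rates are the total rate scaled by the shares, and they sum back to the total:

```latex
Y_t \sim \mathrm{Poisson}(\Lambda_t), \qquad
(a_{t1}, \dots, a_{tN}) \mid Y_t \sim \mathrm{Multinomial}(Y_t;\ \alpha_1, \dots, \alpha_N)
\;\Longrightarrow\;
a_{ti} \sim \mathrm{Poisson}(\Lambda_t \alpha_i) \ \text{independently over } i,

\lambda = \Lambda\,\alpha \in \mathbb{R}^{T \times N}, \quad
\Lambda \in \mathbb{R}^{T \times 1}, \quad
\alpha \in \mathbb{R}^{1 \times N}, \quad
\sum_{i=1}^{N} \lambda_{ti} = \Lambda_t, \quad
\alpha \sim \mathrm{Dirichlet}(c\,\mathbf{1}_N).
```

The last condition holds because the α_i sum to 1, which is exactly the constraint the Dirichlet prior enforces.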

Here’s a sketch of the model:

import numpy as np
import pymc3 as pm

# Inputs defined elsewhere: Lambda (T,1) total rate per time step,
# a_hat (T,N) observed counts per partition, c (Dirichlet concentration) and N.
with pm.Model() as model:
    Λ = pm.Data('Λ', Lambda)    # Total observed rate per time step; Lambda has shape (T,1)

    # Model for each partition's share of the underlying Poisson rate
    α = pm.Dirichlet('α', a=c*np.ones((1, N)), shape=(1, N))

    # Poisson rate per partition: (T,1) x (1,N) -> (T,N)
    λa = pm.Deterministic('λa', pm.math.dot(Λ, α))

    # Observed counts per partition
    â = pm.Data('â', a_hat)                        # â has shape (T,N)
    a = pm.Poisson('a', mu=λa, shape=a_hat.shape, observed=â)
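The model sketch needs `Lambda`, `a_hat`, `c` and `N` from elsewhere. To exercise it without the real data, one option is to simulate a dataset from the same generative story. The sizes, concentration, rate range and seed below are made-up placeholders, not values from the original analysis.

```python
import numpy as np

rng = np.random.default_rng(42)

T, N = 100, 4   # hypothetical: 100 time steps, 4 partitions
c = 5.0         # hypothetical Dirichlet concentration

# Total rate per time step, kept as a (T, 1) column like the model expects.
Lambda = rng.uniform(50.0, 150.0, size=(T, 1))

# "True" shares drawn from the same Dirichlet prior, reshaped to a (1, N) row.
alpha_true = rng.dirichlet(c * np.ones(N))[None, :]

# Matrix product of column and row gives the per-partition rate, shape (T, N).
rate = Lambda @ alpha_true

# Observed counts per partition, shape (T, N).
a_hat = rng.poisson(rate)
```

With these names in scope the model can be built as written, and because `alpha_true` is known, the posterior for α can be checked against it.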

Giving the Dirichlet shape (1,N) ensures that its product with the batch of observed rates Λ, which has shape (T,1), has shape (T,N). The model above works well in versions of PyMC3 before 3.9, but in 3.9 and above I get the error shown above. I don’t understand why a version change should produce invalid values in the logp, and I couldn’t find anything in the release notes that suggests it. Is there a better way to go about this?
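The shape reasoning can be checked in plain NumPy with placeholder values. The sketch confirms that the product of a (T,1) column and a (1,N) row gives the same (T,N) matrix as multiplying the column by a 1-D (N,) share vector, since the row is just the 1-D vector with a leading axis added. That suggests one thing to try, as an untested guess rather than a confirmed fix for the 3.9 behaviour: declare the Dirichlet with the plain 1-D `shape=N` and add the leading axis only when forming the rate, for example `pm.math.dot(Λ, α[None, :])`. Keeping the `dot` form is deliberate: a plain elementwise `Λ * α` may hit Theano's broadcasting rules for shared data created by `pm.Data`.

```python
import numpy as np

rng = np.random.default_rng(0)
T, N = 6, 3                                   # small placeholder sizes

Lambda = rng.uniform(1.0, 10.0, size=(T, 1))  # (T, 1) batch of total rates
alpha_1d = rng.dirichlet(np.ones(N))          # (N,) share vector, one simplex point
alpha_row = alpha_1d[None, :]                 # (1, N) row, as in the original model

via_dot = Lambda @ alpha_row                  # matrix product: (T,1) x (1,N) -> (T,N)
via_broadcast = Lambda * alpha_1d             # broadcasting: (T,1) * (N,) -> (T,N)

print(via_dot.shape, via_broadcast.shape)     # (6, 3) (6, 3)
print(np.allclose(via_dot, via_broadcast))    # True
```

Either way each row of the rate matrix sums to that time step's Λ, because the shares sum to 1.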

Thanks!

It seems to be a bug. Could you raise an issue on GitHub?

Sure, will do :+1:
