Extreme sampling slowness with pm.Dirichlet

Hi,

I think there might be something wrong with pm.Dirichlet, although I’d love for someone to tell me I’m doing something wrong. Consider this MWE:

import pymc as pm

n = 100
h = 61
alpha = 2
beta = 2

with pm.Model(check_bounds=False) as model:
  p = pm.Beta("p", alpha=alpha, beta=beta, shape=(n, h))
  b = pm.Dirichlet("b", a=p, shape=(2, n, h))
  idata = pm.sample(nuts_sampler="numpyro", chains=2)

Replace b with, say, pm.Exponential and sampling takes about 1 minute 30 seconds. With pm.Dirichlet we’re at 28 minutes 39 seconds. numpyro gives a bit of a boost compared to the built-in PyMC sampler, but not much, and nutpie with pm.Dirichlet crashes completely (which may or may not be a related issue).

Is this the expected behavior, i.e. sampling from a Dirichlet distribution simply takes a long time and we just have to live with it? Am I doing something wrong? Or, considering that nutpie doesn’t work with pm.Dirichlet at all, is this a bug?

This is all running on an Apple M1 Pro CPU with 32 GB of RAM and macOS Ventura 13.6. Relevant package versions:
Python 3.11.6
pymc 5.9.0
pytensor 2.17.1
libblas 3.9.0 with accelerate
numpyro 0.13.2

Please let me know what other information I can provide to help diagnose this. I really appreciate all the help this forum can provide and the great work that goes into PyMC :slight_smile:

Thanks.

You are sampling ~12k parameters, in case that helps ground expectations.

Then there are two questions:

  1. How expensive are the logp and gradient evaluations? You can try timing these functions: I wouldn’t be surprised if there’s a factor of 18x there, which would explain the total difference.

  2. How easily the sampler can take draws from each variable. It could be that, for the Dirichlet, NUTS needs twice as many logp/dlogp evaluations per draw, in which case the actual functions would only need to be 9x slower to explain the total difference.

Point 2 is tricky, because a more complex model/function may actually be easier to sample (e.g., it agrees more smoothly with the data) and require fewer evaluations. These differences could also go away with better priors, predictors (and the actual likelihood!).

That means the differences you find in this isolated example may be completely unrelated to what is causing your model of interest to sample slowly.

Anyway, point 1 is a good starting point to check whether the difference is reasonable. My gut feeling says it is.

Beta is a very odd prior to put on the concentration parameters. Unless you have some very peculiar data, you probably want to have reasonable probability for values greater than one. Have a look at what happens when you constrain the alpha and beta parameters of a beta distribution to values less than one – it pushes the probability out to zero and one. With a Dirichlet, you are doing something analogous, except in more dimensions.
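To see this concretely, here is a small NumPy sketch (separate from the model; the concentration values 0.2 and 5.0 are just illustrative) comparing Dirichlet draws with concentrations below versus above one:

```python
import numpy as np

rng = np.random.default_rng(42)
draws = 10_000

# Concentrations below 1: mass piles up in the corners of the simplex,
# so a single component tends to dominate each draw.
sparse = rng.dirichlet([0.2, 0.2, 0.2], size=draws)

# Concentrations above 1: draws stay near the centre of the simplex.
dense = rng.dirichlet([5.0, 5.0, 5.0], size=draws)

print("mean largest component, a = 0.2:", sparse.max(axis=1).mean())
print("mean largest component, a = 5.0:", dense.max(axis=1).mean())
```

With concentrations below one, the largest component of each draw sits close to 1 on average; with concentrations above one, it stays near 1/k. A prior on the concentrations that concentrates below one therefore pushes the Dirichlet toward the edges of the simplex.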


Hi @ricardoV94 and @fonnesbeck , thanks for your replies. Apologies for choosing the beta prior in my example. My actual model of interest uses a gamma, but modifying the MWE here to use pm.Gamma in place of pm.Beta makes little to no difference in speed. And since this is all part of a larger modeling goal, I’ve also tried swapping the Dirichlet in the full model for a beta: all else being equal, I get a significant speed increase (but horrible inferences). It’s a gamma → Dirichlet → multinomial regression, so the Dirichlet makes sense. But I guess there’s not much else I can do other than significantly reduce the number of draws and tuning samples :confused:

It’s not merely the distribution type that matters but its values as well. If your gamma has a lot of mass on values < 1, the issue @fonnesbeck mentioned would still apply.

If you have a Dirichlet → Multinomial, you may also consider (if you haven’t already) using a DirichletMultinomial to marginalize away the Dirichlet variables.