For models with 3+ categorical variables and potentially missing data, I compared sampling times for the data-complete model (no latent discrete variables) and the data-missing model (many latent discrete variables). The latter takes considerably longer to sample (~5-10x longer), and the cost of the CategoricalGibbsMetropolis sampler appears to be the reason why. For context, here’s a minimal model that uses both the continuous and discrete samplers:
import numpy as np
import pymc3 as pm

n = 1000
with pm.Model() as model:
    p = pm.Dirichlet('p', a=np.ones(n))   # continuous: sampled by NUTS
    y = pm.Categorical('y', p=p)          # free discrete: sampled by CategoricalGibbsMetropolis
    trace = pm.sample()
I’ve spent some time looking through the PyMC implementation and am unsure which pieces take the most time to execute and which might be optimized. What profiling strategy would help shed light on this? I’ve been working in Jupyter, so any extensions there are fair game.
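One strategy I’ve considered is wrapping the sampling call with the standard-library `cProfile` and sorting by cumulative time, so the most expensive call paths (e.g. the step method’s inner loop) surface at the top. A minimal sketch below, where `run_sampler` is a placeholder workload standing in for `pm.sample(...)` (hypothetical name, not a PyMC function):

```python
import cProfile
import pstats
import io
import random

def run_sampler(draws=2000):
    # Placeholder workload; in practice this would call pm.sample(draws)
    # inside the model context.
    total = 0.0
    for _ in range(draws):
        total += random.random()
    return total

profiler = cProfile.Profile()
profiler.enable()
run_sampler()
profiler.disable()

# Sort by cumulative time so expensive call paths appear first.
stream = io.StringIO()
stats = pstats.Stats(profiler, stream=stream).sort_stats("cumulative")
stats.print_stats(10)  # show the top 10 entries
report = stream.getvalue()
print(report)
```

In Jupyter, the `%prun` magic gives the same report without the boilerplate (e.g. `%prun -s cumulative pm.sample()` inside the model context).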
EDIT: I should also note that this does not appear to be a problem stemming from the posterior geometry under the alternating Gibbs / NUTS scheme: the slowdown persists when using HMC with a fixed number of steps.