How to implement Zero-Inflated Multinomial?

In the end, this model took three and a half hours to sample :laughing:
There are no divergences, but I’m trying to implement a non-centered version of the MvNormal. I suspect the MvNormal is what makes sampling hard, since the intercept-only model samples in 2 minutes. That would mean that the line:

# varying effects:
ab_cluster = pm.MvNormal(f"ab_cluster{p}", mu=ab, chol=chol, shape=(Nclusters, 2))

becomes:

z = pm.Normal(f"z_{p}", 0., 1., shape=(Nclusters, 2))
# scale the standard normals by the Cholesky factor, then shift by the mean
ab_cluster = pm.Deterministic(f"ab_cluster{p}", ab + tt.dot(chol, z.T).T)

Does this seem right to you? And are there other optimizations I could make to speed up sampling, in your opinion?
(I can share how I simulate the data if useful. I didn’t include it here because it’s a big chunk of code unrelated to PyMC.)
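
One way to sanity-check a non-centered form outside of PyMC is a quick NumPy simulation. The sketch below is only an illustration under assumed values: `ab`, `cov` and the sample size are made up, and `chol` is just the Cholesky factor of that made-up covariance. It compares adding the mean after the Cholesky scaling with adding it before. Under these assumptions, only the first should reproduce draws with mean `ab` and covariance `chol @ chol.T`, while the second ends up with mean `chol @ ab`.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical values standing in for the model's `ab` and `chol`
ab = np.array([1.0, -0.5])
cov = np.array([[1.0, 0.6], [0.6, 2.0]])
chol = np.linalg.cholesky(cov)  # lower-triangular factor, cov = chol @ chol.T

# Standard-normal offsets, one row per (simulated) cluster
z = rng.standard_normal((200_000, 2))

# Scale by the Cholesky factor first, then shift by the mean
shift_after = ab + np.dot(chol, z.T).T

# Alternative: shift first, then scale (the mean gets multiplied by chol too)
shift_before = np.dot(chol, (ab + z).T).T

print("target mean:        ", ab)
print("mean, shift after:  ", shift_after.mean(axis=0))
print("mean, shift before: ", shift_before.mean(axis=0))
print("cov, shift after:\n", np.cov(shift_after, rowvar=False))
```

Both variants have the same covariance, so the difference only shows up in the mean; that makes it easy to miss when only looking at the spread of the cluster effects.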
