Prior predictive samples with multidimensional parameters

I am trying to fit a model similar to Latent Dirichlet Allocation (LDA), but allowing overdispersion. This is my model:

import numpy as np
import pymc3 as pm
from pymc3.distributions.transforms import t_stick_breaking

# genus_counts: pandas DataFrame of counts with shape (n, p)
with pm.Model() as model:
    k = 3
    n, p = genus_counts.shape
    profiles = pm.Dirichlet("profiles", np.ones((k, p)), shape=(k, p), transform=t_stick_breaking(1e-9))
    weights = pm.Dirichlet("weights", np.ones((n, k)), shape=(n, k), transform=t_stick_breaking(1e-9))
    apparent_abundance = pm.Deterministic("apparent_abundance", pm.math.dot(weights, profiles))
    overdispersion = pm.Exponential("overdispersion", 1)
    totals = genus_counts.values.sum(axis=1)[:, None]  # total count per row, shape (n, 1)
    read_counts = pm.NegativeBinomial("read_counts", mu=totals * apparent_abundance, alpha=1 / overdispersion,
                                      shape=(n, p), observed=genus_counts.values)

Here genus_counts is a pandas DataFrame of counts with n rows and p columns.
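To make the expected prior predictive shapes concrete, here is a plain NumPy sketch of the same generative process with an explicit leading axis for the draws. It is not PyMC3 code: the sizes, the `totals` vector standing in for `genus_counts.values.sum(axis=1)`, and drawing the negative binomial as a gamma-Poisson mixture (mean `mu`, variance `mu + mu**2 / alpha`, matching PyMC3's `mu`/`alpha` parameterization) are assumptions made for illustration.

```python
import numpy as np

# Hypothetical sizes and per-row totals standing in for genus_counts
draws, n, k, p = 500, 4, 3, 6
rng = np.random.default_rng(0)
totals = rng.integers(1_000, 5_000, size=n)  # assumed row sums of the count matrix

# One independent set of parameters per prior draw (leading axis = draws)
profiles = rng.dirichlet(np.ones(p), size=(draws, k))  # (draws, k, p)
weights = rng.dirichlet(np.ones(k), size=(draws, n))   # (draws, n, k)
apparent_abundance = weights @ profiles                # (draws, n, p), batched matmul
overdispersion = rng.exponential(1.0, size=draws)      # (draws,)

# Negative binomial with mean mu and alpha = 1 / overdispersion,
# drawn as a gamma-Poisson mixture
mu = totals[None, :, None] * apparent_abundance        # (draws, n, p)
alpha = 1.0 / overdispersion[:, None, None]            # (draws, 1, 1)
rates = rng.gamma(shape=alpha, scale=mu / alpha)       # (draws, n, p)
read_counts = rng.poisson(rates)                       # (draws, n, p)

for name, arr in [("profiles", profiles), ("weights", weights),
                  ("apparent_abundance", apparent_abundance),
                  ("overdispersion", overdispersion), ("read_counts", read_counts)]:
    print(name, arr.shape)
```

Because every row of `weights` and of `profiles` sums to one, every row of `apparent_abundance` sums to one as well, so each row of `mu` sums to that row's total count.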

Then I used pm.sample_prior_predictive to check whether my prior distribution was OK (I called it with no arguments, inside the model context manager). My overdispersion parameter had shape (500,), as expected, since I took 500 draws from the prior predictive. The variable read_counts had shape (500, n, p), which also makes sense.

But profiles had shape (k, p), weights had shape (n, k) and apparent_abundance had shape (n, p), with no leading draws dimension, and I don't understand why: why did the sampler return only a single matrix for apparent_abundance instead of 500? I suspect an issue with multidimensional distributions and shape parameters, but read_counts was sampled with no problem (maybe because it is tagged as observed?).

Thanks in advance!

The issue is specific to the Dirichlet distribution: it doesn't handle multidimensional shapes terribly well. This is a known bug, partially referenced in this issue. I've hacked out a workaround that uses the fact that a Dirichlet random vector can be represented as independent Gamma random variables normalized by their sum. You can find it in this gist.
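A minimal NumPy sketch of the Gamma representation described above (an illustration of the idea, not the code from the gist): draw independent Gamma(alpha_ij, 1) variables with the full (k, p) shape and divide each row by its sum, which yields k independent Dirichlet vectors per draw. The concentration matrix is an arbitrary example.

```python
import numpy as np

rng = np.random.default_rng(1)
draws = 20_000
# Arbitrary example concentrations: k = 2 rows, p = 3 categories
alpha = np.array([[1.0, 2.0, 3.0],
                  [4.0, 5.0, 6.0]])

# Independent Gamma(alpha_ij, 1) draws, one (k, p) matrix per draw
g = rng.gamma(shape=alpha, scale=1.0, size=(draws,) + alpha.shape)  # (draws, k, p)

# Normalizing each row over the last axis gives Dirichlet(alpha_i) rows
samples = g / g.sum(axis=-1, keepdims=True)  # (draws, k, p)

# Monte Carlo check: the Dirichlet mean is alpha_i / sum(alpha_i)
print(samples.mean(axis=0))
print(alpha / alpha.sum(axis=-1, keepdims=True))
```

In a PyMC3 model, the analogous move would be a `pm.Gamma` variable of shape `(k, p)` followed by a `pm.Deterministic` that divides it by its row sums; treat that as a suggestion rather than a reproduction of the linked workaround.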

For a minimal reproducible example, see below.

import numpy as np
import pymc3 as pm

k = 2
p = 2
with pm.Model() as model:
    x = pm.Dirichlet('x', np.ones(p), shape=p)            # a single Dirichlet vector
    y = pm.Dirichlet('y', np.ones([k, p]), shape=(k, p))  # k Dirichlet vectors, one per row

    trace = pm.sample_prior_predictive()

# Print the shape of every variable in the prior predictive samples
for var, t in trace.items():
    print(var, t.shape)