Let’s say that we have access to N observations of a phenomenon, where for each observation i we observe a fixed number m of independent draws from a Negative Binomial distribution with parameters \alpha_i (the shape parameter of the underlying Gamma distribution) and \mu_i (the mean of the underlying Gamma distribution, which is also the Negative Binomial mean).
The \alpha parameter is shared across all observations, thus \forall i, \alpha_i = \alpha.
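Because \alpha is described as the Gamma shape and \mu as the Gamma mean, a small NumPy sketch of the corresponding Gamma-Poisson mixture may make the parametrization concrete. It relies on the standard identity that a Poisson whose rate is drawn from a Gamma with shape \alpha and mean \mu is Negative Binomial with mean \mu and variance \mu + \mu^2/\alpha. The helper name `sample_nb_gamma_poisson` and the numbers are illustrative assumptions, not part of the original model.

```python
import numpy as np

def sample_nb_gamma_poisson(alpha, mu, size, rng):
    """Hypothetical helper: Negative Binomial draws via a Gamma-Poisson mixture.

    alpha is the Gamma shape and mu the Gamma mean, so the Gamma scale is mu / alpha.
    """
    rates = rng.gamma(shape=alpha, scale=mu / alpha, size=size)  # latent Poisson rates
    return rng.poisson(rates)                                    # one count per rate

rng = np.random.default_rng(0)
alpha, mu = 2.0, 10.0
draws = sample_nb_gamma_poisson(alpha, mu, size=200_000, rng=rng)

print(draws.mean())  # close to mu = 10
print(draws.var())   # close to mu + mu**2 / alpha = 60
```

Smaller \alpha gives a heavier right tail; as \alpha grows the draws approach a Poisson with mean \mu.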
However, \mu_i depends on a latent variable z_i that assigns every observation i to a hidden state among k possible states.
The latent variables z can be seen as indicator variables and can be modeled with a \mathrm{Multinomial}(n=1, p_1,...,p_k), where p_1,...,p_k are the probabilities of the different outcomes. In this case, z_i is a binary vector of size k with a single 1 at the state to which observation i belongs.
Alternatively, the latent variables z can also be modeled using a Categorical distribution with probabilities p_1,...,p_k.
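Both encodings carry the same information: a one-hot Multinomial(n=1) row and a Categorical index select the same state mean. A minimal NumPy sketch, where the state means and assignments are made up for illustration and indices are 0-based as in Python (so "state 2" is index 1):

```python
import numpy as np

state_means = np.array([1, 10, 20, 30])                  # illustrative known means, k = 4
z_index = np.array([1, 0, 3, 1])                         # Categorical encoding: one index per observation
z_onehot = np.eye(len(state_means), dtype=int)[z_index]  # Multinomial(n=1) encoding: one-hot rows

mu_from_onehot = z_onehot @ state_means  # the product keeps only the mean of the active state
mu_from_index = state_means[z_index]     # direct indexing gives the same result

print(z_onehot[0])     # [0 1 0 0]
print(mu_from_onehot)  # [10  1 30 10]
assert np.array_equal(mu_from_onehot, mu_from_index)
```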
Additionally, the Negative Binomial mean of each state 1,...,k is known.
For example, if observation i belongs to the second state, then z_i = 2, and this observation is a vector of m independent draws from a Negative Binomial with parameters \alpha and \mu_2, the mean of state 2.
Given a set of observations following this model, we want to infer \alpha, the latent variable z_i of each observation, and the state probabilities p_1,...,p_k.
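Written out, the model above reads as follows. This is a restatement of the text, where x_{il} (an index introduced here only for notation) is the l-th of the m draws of observation i, and the Dirichlet prior is the one used in the inference models below. Marginalizing z_i gives the usual mixture likelihood and the posterior probability of each state:

$$
\begin{aligned}
(p_1,\dots,p_k) &\sim \mathrm{Dirichlet}(1,\dots,1) \\
z_i \mid p &\sim \mathrm{Categorical}(p_1,\dots,p_k), \qquad i = 1,\dots,N \\
x_{il} \mid z_i, \alpha &\sim \mathrm{NB}(\mu_{z_i}, \alpha), \qquad l = 1,\dots,m \\
p(x_i \mid \alpha, p) &= \sum_{j=1}^{k} p_j \prod_{l=1}^{m} \mathrm{NB}(x_{il} \mid \mu_j, \alpha) \\
p(z_i = j \mid x_i, \alpha, p) &= \frac{p_j \prod_{l=1}^{m} \mathrm{NB}(x_{il} \mid \mu_j, \alpha)}{\sum_{j'=1}^{k} p_{j'} \prod_{l=1}^{m} \mathrm{NB}(x_{il} \mid \mu_{j'}, \alpha)}
\end{aligned}
$$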
Observations simulation
```python
import numpy as np
import pymc as pm

N, M, max_state = 10, 5, 8
prob = np.random.rand(max_state)
state_prob = prob / prob.sum()  # probabilities of each state
expected_mean = np.concatenate((np.array([1]), np.arange(1, max_state) * 10))  # mean of each state
alpha_gt = 0.0001  # ground truth; the NegativeBinomial below receives 1/alpha_gt as its alpha

with pm.Model() as ind_simulation:
    w = state_prob
    latent_z = pm.Multinomial("z", n=1, p=w, shape=(N, max_state))
    state_assignment = pm.draw(latent_z)  # for debugging purposes, draw one realization of the latent variable
    alpha = alpha_gt
    # One mean per observation (one-hot row @ state means), repeated over the M draws -> shape (N, M)
    mu = np.repeat((state_assignment @ expected_mean)[:, np.newaxis], M, axis=1)
    obs_data = pm.NegativeBinomial("ind", alpha=1 / alpha, mu=mu)
    obs_draw = pm.draw(obs_data)
```
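To check simulated data without building a PyMC model, the same generative process can be sketched in plain NumPy. This is an assumed equivalent, not the code used above: `simulate_observations` is a hypothetical helper, states are drawn as 0-based Categorical indices, and NumPy's `negative_binomial(n, p)` is mapped to the mean/shape parametrization with n = shape and p = shape / (shape + mu), which gives mean mu. As in the PyMC code, the shape passed is 1/alpha_gt.

```python
import numpy as np

def simulate_observations(state_prob, expected_mean, n_obs, n_draws, shape, rng):
    """Hypothetical NumPy version of the simulation above.

    Returns the 0-based state of each observation and an (n_obs, n_draws) array of counts.
    `shape` is the Negative Binomial (Gamma) shape, i.e. 1/alpha_gt in the PyMC code.
    """
    states = rng.choice(len(state_prob), size=n_obs, p=state_prob)  # one hidden state per observation
    mu = expected_mean[states][:, np.newaxis]                        # (n_obs, 1) mean per observation
    p = shape / (shape + mu)                                         # NumPy's success probability
    data = rng.negative_binomial(shape, p, size=(n_obs, n_draws))    # p broadcasts over the draws
    return states, data

rng = np.random.default_rng(42)
N, M, max_state = 10, 5, 8
prob = rng.random(max_state)
state_prob = prob / prob.sum()
expected_mean = np.concatenate(([1], np.arange(1, max_state) * 10))
states, obs = simulate_observations(state_prob, expected_mean, N, M, shape=1 / 0.0001, rng=rng)
print(obs.shape)  # (10, 5)
```

With a shape of 10000 the draws are almost Poisson, so each row should scatter tightly around expected_mean[states].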
Here is a heatmap of the simulated data, which looks as expected in terms of size and values: the values on each row appear to be drawn from the same Negative Binomial distribution.
Sampling
We can now run the inference:
```python
import pytensor.tensor as pt

with pm.Model() as inference_model:
    w = pm.Dirichlet("w", a=np.ones(max_state))
    latent_z = pm.Multinomial("z", n=1, p=w, shape=(N, max_state))
    alpha = pm.Uniform("alpha", lower=0, upper=0.1)
    # Symbolic (N, M) mean matrix: one-hot z @ state means, repeated over the M draws
    mu = pt.repeat(pm.math.matmul(latent_z, expected_mean.reshape(-1, 1))[:, 0][:, np.newaxis], M, axis=1)
    obs_distrib = pm.NegativeBinomial("obs", alpha=1 / alpha, mu=mu, observed=obs_draw)
    obs_sample = pm.sample()

pm.model_to_graphviz(model=inference_model)
```
We pass the \mu parameter of the Negative Binomial as explicitly as possible, as a matrix with the same shape (N, M) as the observations, to avoid any ambiguity.
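As an aside on the explicit (N, M) matrix: under NumPy broadcasting rules (which PyTensor also follows), a column of per-observation means of shape (N, 1) expands to the same (N, M) matrix, so the repeat is a matter of explicitness rather than necessity. A NumPy sketch with made-up values:

```python
import numpy as np

N, M = 3, 4
row_means = np.array([1.0, 10.0, 20.0])  # illustrative mean per observation

mu_repeated = np.repeat(row_means[:, np.newaxis], M, axis=1)      # explicit (N, M) matrix
mu_broadcast = np.broadcast_to(row_means[:, np.newaxis], (N, M))  # (N, 1) broadcast to (N, M)

assert np.array_equal(mu_repeated, mu_broadcast)
print(mu_repeated.shape)  # (3, 4)
```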
I would have expected the graph to show z and the observations in a large shared plate of size N, with the observations nested in a smaller plate of size M.
Inference results
First, let’s look at the trace:
Then, looking at the posteriors of the latent variables, we observe that all observations were, strangely, assigned to the first state.
The state probabilities are also off, with the first state having the highest probability.
The shape parameter \alpha also reaches the upper bound of its prior (0.1), allowing wide distributions.
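To put numbers on the assignment plots, one option is to count, for each observation, how often each state appears in the posterior draws. The sketch assumes the draws of z have been extracted from the trace as a NumPy array (for example with `obs_sample.posterior["z"].values`), of shape (chains, draws, N, k) for the one-hot Multinomial encoding or (chains, draws, N) for the Categorical one. `state_frequencies` is a hypothetical helper and the draws below are random placeholder data.

```python
import numpy as np

def state_frequencies(z_draws, k):
    """Hypothetical helper: posterior frequency of each state for each observation.

    z_draws: integer array of shape (chains, draws, N, k) (one-hot)
             or (chains, draws, N) (state indices).
    Returns an (N, k) array whose rows sum to 1.
    """
    if z_draws.ndim == 4:                          # one-hot encoding -> state index
        z_draws = z_draws.argmax(axis=-1)
    flat = z_draws.reshape(-1, z_draws.shape[-1])  # (chains * draws, N)
    counts = np.stack([(flat == j).sum(axis=0) for j in range(k)], axis=1)  # (N, k)
    return counts / flat.shape[0]

# Placeholder draws: 2 chains, 50 draws, N = 3 observations, k = 4 states
rng = np.random.default_rng(1)
idx = rng.integers(0, 4, size=(2, 50, 3))
onehot = np.eye(4, dtype=int)[idx]
freq = state_frequencies(onehot, k=4)
print(freq.shape)           # (3, 4)
print(freq.argmax(axis=1))  # most frequent state per observation
```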
With categorical distribution
I then tried modeling the latent variables with a Categorical distribution, which results in a new inference model:
```python
import pytensor

with pm.Model() as inference_model_categorical:
    w = pm.Dirichlet("w", a=np.ones(max_state))
    latent_z = pm.Categorical("z", p=w, shape=N)
    alpha = pm.Uniform("alpha", lower=0, upper=0.1)
    mu = pytensor.shared(expected_mean)[latent_z]  # mean of each observation's state, shape (N,)
    # obs_draw.T has shape (M, N), so mu broadcasts along the last axis
    obs_distrib = pm.NegativeBinomial("obs", alpha=1 / alpha, mu=mu, observed=obs_draw.T)
    obs_sample = pm.sample()

pm.model_to_graphviz(model=inference_model_categorical)
```
Again, the graph does not display nested plates.
However, this time the inference works fine.
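When w and the shape are held fixed, the posterior probability of each state for an observation has a closed form (the ratio of mixture terms), which can serve as a reference when comparing the sampled z. The sketch below is a hypothetical NumPy helper, not part of PyMC. It uses a `shape` argument for the Negative Binomial (Gamma) shape, which corresponds to `1 / alpha` in the models above, and the data in the usage example are made up.

```python
import math
import numpy as np

_lgamma = np.vectorize(math.lgamma)  # elementwise log-gamma using only the standard library

def nb_logpmf(x, mu, shape):
    """Log-pmf of a Negative Binomial with mean mu and Gamma shape `shape`."""
    return (_lgamma(x + shape) - _lgamma(shape) - _lgamma(x + 1)
            + shape * np.log(shape / (shape + mu)) + x * np.log(mu / (shape + mu)))

def state_posterior(obs, w, state_means, shape):
    """Hypothetical helper: p(z_i = j | x_i, w, shape) for obs of shape (N, M).

    Returns an (N, k) array whose rows sum to 1.
    """
    # log p_j + sum over the M draws of log NB(x_il | mu_j, shape), shape (N, k)
    log_joint = np.log(w)[None, :] + nb_logpmf(
        obs[:, :, None], state_means[None, None, :], shape).sum(axis=1)
    log_joint -= log_joint.max(axis=1, keepdims=True)  # stabilize before exponentiating
    post = np.exp(log_joint)
    return post / post.sum(axis=1, keepdims=True)

state_means = np.array([1.0, 10.0, 20.0])
w = np.array([0.2, 0.5, 0.3])
obs = np.array([[0, 2, 1, 1, 0],
                [9, 11, 12, 8, 10],
                [19, 22, 21, 18, 20]])
post = state_posterior(obs, w, state_means, shape=50.0)
print(post.argmax(axis=1))  # [0 1 2]
```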
Can someone please help me understand what is happening here?