Confusion about the use of the shape parameter in a multinomial likelihood

I am trying to model data with 6 repeats (number of rows) and 40 categories (number of columns) using a multinomial distribution (see below for the data and code). I tried to model this data using two different shape arguments to the likelihood function, which I thought would implicitly broadcast to the same thing, but the results turned out to be completely different for low-count observations:

import pymc3 as pm
import arviz as az
import numpy as np

default_sample_parameters = {
    'draws':1000,
    'tune':1000,
    'chains':6,
    'cores':6,
    'return_inferencedata':True,
    'progressbar':True,
    'target_accept':0.95
    }

def fit_multinomial(observed, sample_parameters=None, hdi_prob=0.95,
                    lshape=None):

  if sample_parameters is None:
    sample_parameters = {}

  # merge caller-supplied sampling parameters over the defaults
  sample_parameters = dict(default_sample_parameters, **sample_parameters)

  nrepeats, ncategories = observed.shape

  if lshape is None:
    lshape = (nrepeats)

  with pm.Model() as model:
    # flat Dirichlet prior over the category probabilities, with a
    # leading axis added so it broadcasts against the rows of the data
    p0 = pm.Dirichlet("p0", a=np.ones(ncategories))[None, :]

    # total count per row
    n0 = np.sum(observed, axis=1)[:, None]
    pm.Multinomial("initial_sample", n=n0, p=p0, shape=lshape,
                   observed=observed)
    trace = pm.sample(**sample_parameters)

  with model:
    log_p0 = pm.Deterministic("log_p0", pm.math.log(p0))
    ppt = pm.sample_posterior_predictive(trace, var_names=["log_p0"],
                                         keep_size=True)

  summary = az.summary(ppt, var_names=["log_p0"],
                       hdi_prob=hdi_prob, skipna=True)

  return summary


observed = np.array(
      [[388156,  24806,     35,   6540,   4952,    186,   2381,   5482,
            27,    112,    172,    196,  10050,   9847,    284,  14606,
           322,    758,    330,    221,     12,   1057,     12,      5,
           112,    298,    563,     78,    480,     17,    782,    139,
            20,   4552,    940,   4689,   7462,    596,    804,   7919],
       [375116,  25604,      7,   8479,   3496,    116,   2054,   4539,
           338,    124,    171,    511,  12027,  12498,    275,   8833,
           816,    781,    647,    347,      5,    633,     13,     11,
           119,    216,    882,    149,    611,     10,    795,    126,
             3,   4091,   1315,   4456,   8873,    531,    942,  19440],
       [378762,  17261,     37,   8731,   4920,    273,   2228,   4673,
           258,     93,    268,    820,  13556,  13737,    116,  14917,
           698,    625,    451,    328,     13,    974,     14,     10,
           193,    119,    789,    109,    719,     12,   1187,     86,
            24,   5416,    980,   5262,   7077,    544,    610,  13110],
       [378996,  17881,     53,   8014,   4439,    218,   2744,   4601,
             0,    120,    264,    411,  11867,  10576,    547,  11806,
          1073,   1190,    788,    430,     19,    766,     24,     15,
           118,    202,    610,    184,    591,     16,    922,     95,
            38,   4625,   1064,   5943,   9623,    632,    844,  17651],
       [372024,  21671,     29,   6377,   3952,    253,   3054,   4753,
           156,    170,    148,    843,   5277,  13510,    469,  13860,
           528,    822,    510,    375,     15,   1051,      9,      8,
           196,    234,   1270,    259,    559,     26,   1008,     93,
             7,   6050,   1202,   5209,  11087,    396,   1085,  21455],
       [373764,  19591,     34,   9503,   4824,    209,   2843,   7107,
           114,    153,    254,    596,   9898,  10053,    489,  14993,
           797,   1021,    710,    303,     11,    781,     28,      2,
            64,    235,   1200,    135,    770,     16,    996,    208,
            28,   4613,   1151,   5190,  10907,    623,    832,  14954]])


nrepeats, ncategories = observed.shape
lshape = (nrepeats)
observed_proportions = observed / np.sum(observed, axis=1)[:, None]

summary = fit_multinomial(observed, lshape=lshape)

The first shape argument I supplied was lshape = (nrepeats, ncategories) and the second one I tried was (nrepeats). Initially I expected them to behave the same through implicit broadcasting, but they came out quite different for low-count observations.

I know the data I am supplying above does not fit this model exactly, so I was not expecting the model to work very well. But somehow when lshape = (nrepeats) it works quite well: when I scatter-plot the fitted log_p0 against np.log(np.nanmean(observed_proportions, axis=0)), wherever that is well defined, the agreement is good. This made me realize I may not be understanding how the shape parameter is handled, and I don't want to adopt lshape = nrepeats without understanding why it comes out better. I still suspect the first version is the more correct one, because I would expect some discrepancy for lower-count observations.

So my question is: how is each shape argument handled in the likelihood? I have also seen that with nonsensical shape arguments like (nrepeats, nrepeats) the model still runs, even though the r_hat values and such come out quite bad. This makes me think that I don't understand at all how the shape argument is handled internally. Thanks

PS: On the following page, in the first model, they also use something analogous to (nrepeats, ncategories):
Dirichlet mixtures of multinomials
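
PS2: A side note on the Python itself (not PyMC-specific), which may matter when reading the code above: (nrepeats) is just the int nrepeats, since parentheses alone do not make a tuple, so the second call actually passes shape=6 rather than shape=(6,):

nrepeats, ncategories = 6, 40
print(type((nrepeats)))         # <class 'int'>
print(type((nrepeats,)))        # <class 'tuple'>
print((nrepeats, ncategories))  # (6, 40)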

This may help: Distribution Dimensionality - PyMC 5.5.0 documentation

Especially the multivariate examples.
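
For instance (a minimal sketch with made-up data, using the pymc v4+ API rather than pymc3), you can spell out the batch and support dimensions explicitly and let Multinomial broadcast a per-row n against a shared p:

import numpy as np
import pymc as pm

nrepeats, ncategories = 6, 40
rng = np.random.default_rng(0)
observed = rng.multinomial(1000, np.ones(ncategories) / ncategories,
                           size=nrepeats)

with pm.Model():
    p0 = pm.Dirichlet("p0", a=np.ones(ncategories))  # support shape (ncategories,)
    n0 = observed.sum(axis=1)                        # batch: one total count per row
    # full shape = batch (nrepeats,) x support (ncategories,), matching observed
    pm.Multinomial("y", n=n0, p=p0,
                   shape=(nrepeats, ncategories), observed=observed)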

Thanks for the reply. I noticed that these examples are for the latest pymc, while I was using pymc3, so I updated to pymc and found that when I supply the same shape arguments, it now complains about a shape mismatch in some instances. I will investigate a bit more and get back.

Yeah, shapes are more strict (and predictable) in the newer versions.
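
For example, one quick way to see what shape the newer versions infer (a sketch assuming pymc >= 4, where pm.draw is available):

import numpy as np
import pymc as pm

n0 = np.array([10, 20, 30])   # one total count per row
p0 = np.full(4, 0.25)         # shared probabilities over 4 categories

# the shape is inferred by broadcasting n against p
rv = pm.Multinomial.dist(n=n0, p=p0)
print(pm.draw(rv).shape)      # (3, 4): batch dim from n, support dim from p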

I can confirm (sorry for forgetting to reply) that they do behave as I expected when I used pymc instead of pymc3.
