Inferring Dirichlet concentration parameter

I am trying to infer the most likely concentration parameter for samples from a Dirichlet distribution, but I am struggling to set this up in PyMC3.

My data is

import numpy as np

x = np.array([[0.00560805, 0.16788784, 0.06668699, 0.02216397, 0.21703139,
        0.        , 0.04982627, 0.00448644, 0.11845169, 0.21861628,
        0.00616885, 0.12307223],
       [0.19244308, 0.02297949, 0.09517768, 0.03622034, 0.07414341,
        0.19984624, 0.01186158, 0.11298382, 0.043279  , 0.09359883,
        0.08446663, 0.0329999 ],
       [0.22363757, 0.00368887, 0.08686129, 0.0055333 , 0.11712153,
        0.19620162, 0.0165999 , 0.09579527, 0.0055333 , 0.14657483,
        0.08769705, 0.01475547]])

and my model setup is

import pymc3 as pm

n = x.shape[0]
K = x.shape[1]

with pm.Model() as model:
    alpha = pm.Multinomial("alpha", n=n, p=np.ones(12), shape=K)
    theta = pm.Dirichlet('theta', a=alpha, shape=K, observed=x)

    trace = pm.sample(draws=1000) 

but the traceplot I get does not make any sense to me.

I am quite new to PyMC3 and could not figure out what I am doing wrong here.

Hi Fabian,

  • The concentration parameters of the Dirichlet must be positive real numbers, which is not what the Multinomial produces: it is a generalization of the Binomial, so its draws are vectors of non-negative integers counting how many of the n trials fell into each category. Because the concentrations need a distribution over the positive reals, a Gamma distribution is a common and usually adequate prior for them.
  • You don't need to pass shape to your likelihood: since you give it observed data, it knows what shape it needs.
  • For future reference, your Multinomial is also incorrectly parametrized: p should be a vector of probabilities summing to 1, but you're passing a probability of one for each category. It doesn't raise an error because PyMC3 normalizes the p vector under the hood, so it ends up as equal probabilities (1/12 each) across the categories, but it's definitely good to know.
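The first and third points can be checked with a few lines of numpy. This is a sketch, not PyMC3 code: numpy's samplers stand in for PyMC3's Multinomial and Gamma, and the hyperparameters (n = 3, alpha = 5, beta = 1) are arbitrary. A Multinomial draw is a vector of integer counts with many zeros, a Gamma draw is a vector of strictly positive reals, and normalizing np.ones(12) gives 1/12 per category, mirroring the normalization described above.

```python
import numpy as np

rng = np.random.default_rng(0)
K, n = 12, 3

# np.ones(K) is not a probability vector; normalizing it gives 1/K per category.
p = np.ones(K) / np.ones(K).sum()

# A Multinomial draw: integer counts that sum to n. With n = 3 and K = 12,
# at least 9 of the 12 entries are 0, and 0 is not a valid concentration.
counts = rng.multinomial(n, p)

# A Gamma draw: strictly positive reals, which is what a Dirichlet concentration needs.
# PyMC3's Gamma(alpha, beta) uses a rate beta; numpy's gamma takes scale = 1 / beta.
alpha_hyper, beta_hyper = 5.0, 1.0
conc = rng.gamma(shape=alpha_hyper, scale=1.0 / beta_hyper, size=K)

print(counts)
print(conc)
```

The rate/scale difference is the main trap when translating Gamma hyperparameters between the two libraries.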

Hope this helps 🖖

Hi Alex, thank you for your helpful answer!

I replaced the Multinomial with a Gamma distribution and chose arbitrary hyperparameters for it:

n = x.shape[0]
K = x.shape[1]

with pm.Model() as model:
    alpha = 5.
    beta = 1.
    
    gamma = pm.Gamma("gamma", alpha=alpha, beta=beta)
    theta = pm.Dirichlet('theta', a=gamma, observed=x) 

    trace = pm.sample(draws=1000) 

but I get a TypeError:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-8-eb78b2b21632> in <module>
      7 
      8     gamma = pm.Gamma("gamma", alpha=alpha, beta=beta)
----> 9     theta = pm.Dirichlet('theta', a=gamma, observed=x)
     10 
     11     trace = pm.sample(draws=1000)

C:\ProgramData\Anaconda3\lib\site-packages\pymc3\distributions\distribution.py in __new__(cls, name, *args, **kwargs)
     44                 raise TypeError("observed needs to be data but got: {}".format(type(data)))
     45             total_size = kwargs.pop('total_size', None)
---> 46             dist = cls.dist(*args, **kwargs)
     47             return model.Var(name, dist, data, total_size)
     48         else:

C:\ProgramData\Anaconda3\lib\site-packages\pymc3\distributions\distribution.py in dist(cls, *args, **kwargs)
     55     def dist(cls, *args, **kwargs):
     56         dist = object.__new__(cls)
---> 57         dist.__init__(*args, **kwargs)
     58         return dist
     59 

C:\ProgramData\Anaconda3\lib\site-packages\pymc3\distributions\multivariate.py in __init__(self, a, transform, *args, **kwargs)
    478 
    479         kwargs.setdefault("shape", shape)
--> 480         super().__init__(transform=transform, *args, **kwargs)
    481 
    482         self.size_prefix = tuple(self.shape[:-1])

C:\ProgramData\Anaconda3\lib\site-packages\pymc3\distributions\distribution.py in __init__(self, shape, dtype, defaults, *args, **kwargs)
    191         if dtype is None:
    192             dtype = theano.config.floatX
--> 193         super().__init__(shape, dtype, defaults=defaults, *args, **kwargs)
    194 
    195 

C:\ProgramData\Anaconda3\lib\site-packages\pymc3\distributions\distribution.py in __init__(self, shape, dtype, testval, defaults, transform, broadcastable)
     61                  transform=None, broadcastable=None):
     62         self.shape = np.atleast_1d(shape)
---> 63         if False in (np.floor(self.shape) == self.shape):
     64             raise TypeError("Expected int elements in shape")
     65         self.dtype = dtype

TypeError: must be real number, not TensorVariable

I could get rid of it by passing shape=K to both the Gamma and the Dirichlet variables, but when I try to sample, it crashes (I am working in a Jupyter notebook and lose the server connection).
Apparently the shape parameter is needed, or I am doing something else wrong.

Yeah, you definitely need to pass the shape arg to the Gamma, otherwise it can't infer the length of the vector you need, i.e. one entry per category of your Dirichlet: shape=K here, I think.

And yeah, you may also need to specify the shape of the Dirichlet, even though it's a likelihood (my bad here, I forgot the Dirichlet is a little special), but try without it first.

I am actually not quite sure how to assign the shapes.
My data is an (n, K) array, where n is the number of samples and K is the dimension of each sample. So the Gamma distribution should have shape=K, and the Dirichlet as well, right?
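In plain numpy terms, the shapes fit together like this: the concentration is a single length-K vector shared by all n observed rows, it broadcasts against the (n, K) data to give one log-density per row, and the log-likelihood is the sum over rows. This is a hedged sketch of the math only, not of PyMC3's internals; the helper dirichlet_loglik, the concentration values, and the simulated data are all hypothetical.

```python
import math
import numpy as np

def dirichlet_loglik(x, a):
    """Total Dirichlet(a) log-likelihood of the rows of x.

    x: (n, K) array, each row strictly inside the simplex (no exact zeros).
    a: (K,) array of positive concentrations, shared by every row.
    """
    x = np.asarray(x, dtype=float)
    a = np.asarray(a, dtype=float)
    # Log normalizing constant, identical for every row.
    log_norm = math.lgamma(a.sum()) - sum(math.lgamma(ai) for ai in a)
    # (a - 1) has shape (K,) and broadcasts against log(x) of shape (n, K).
    per_row = log_norm + ((a - 1.0) * np.log(x)).sum(axis=1)  # shape (n,)
    return per_row.sum()

rng = np.random.default_rng(1)
K, n = 12, 500
a_true = rng.gamma(shape=5.0, scale=1.0, size=K)  # hypothetical concentration, shape (K,)
x_sim = rng.dirichlet(a_true, size=n)             # simulated data, shape (n, K)
print(x_sim.shape, a_true.shape)                  # (500, 12) (12,)
print(dirichlet_loglik(x_sim, a_true))
```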

But unfortunately, my notebook crashes. Could this be a Jupyter problem?

A Dirichlet draw is a vector of probabilities, one per category: [p_1, p_2, \ldots, p_k], with \sum_{i=1}^{k} p_i = 1. The Dirichlet is a generalization of the Beta distribution to more than two categories.
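Both statements can be checked numerically with numpy's Dirichlet sampler; the concentration values below are arbitrary. Every draw sums to 1, and with two categories the first component behaves like a Beta(a_1, a_2) variable, whose mean is a_1 / (a_1 + a_2).

```python
import numpy as np

rng = np.random.default_rng(2)

# K = 12 categories: every draw is a probability vector.
draws = rng.dirichlet(np.full(12, 2.0), size=5)
print(draws.sum(axis=1))  # each entry is 1 up to floating-point rounding

# K = 2: the first component of Dirichlet([a1, a2]) is Beta(a1, a2)-distributed.
a1, a2 = 2.0, 5.0
first = rng.dirichlet([a1, a2], size=200_000)[:, 0]
beta = rng.beta(a1, a2, size=200_000)
print(first.mean(), beta.mean(), a1 / (a1 + a2))  # sample means close to 2/7
```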

What do you mean by the notebook crashes? What’s the error you’re getting?

Hi, yes, that's why I use the Dirichlet to model my data: each sample is a distribution over K items, but I do not know much about the underlying generative process. So I thought a Dirichlet would be appropriate, and I want to see how the concentration parameters change across subsets of the whole sample.
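For a quick, non-Bayesian sanity check of how concentrations differ between subsets, a moment-matching estimate can sit alongside the PyMC3 fit. This swaps in a different technique: for a Dirichlet with total concentration s = \sum_j a_j and mean m_j = a_j / s, each component has variance m_j (1 - m_j) / (s + 1), so s can be solved from the sample mean and variance, and then a_j ≈ m_j s. Averaging the per-component estimates of s is one simple choice among several; the helper name, the true concentrations, and the two-halves split are hypothetical.

```python
import numpy as np

def dirichlet_moments_estimate(x):
    """Moment-matching estimate of Dirichlet concentrations from the rows of x."""
    x = np.asarray(x, dtype=float)
    m = x.mean(axis=0)         # estimates a_j / s
    v = x.var(axis=0, ddof=1)  # estimates m_j (1 - m_j) / (s + 1)
    # One estimate of s per component, averaged across components.
    s = np.mean(m * (1.0 - m) / v - 1.0)
    return m * s

rng = np.random.default_rng(3)
a_true = np.array([2.0, 5.0, 3.0, 0.5])
x_sim = rng.dirichlet(a_true, size=100_000)
print(dirichlet_moments_estimate(x_sim))  # should be close to a_true

# Hypothetical subsets: both halves come from the same a_true here,
# so both estimates should be close to it and to each other.
half = len(x_sim) // 2
print(dirichlet_moments_estimate(x_sim[:half]))
print(dirichlet_moments_estimate(x_sim[half:]))
```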

Building the model seems to work: I don't get any error and can access both the gamma and theta variables. But when I run trace = pm.sample(draws=1000) (inside the model context, of course), my notebook server crashes with the error message:

Server Connection Error
A connection to the Jupyter server could not be established. JupyterLab will continue trying to reconnect. Check your network connection or Jupyter server configuration.

I am sure that the call to sample causes this, but I don't know what I need to change.

When I reproduced the code in the IPython CLI, I got the following error message:

Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [gamma]
Sampling 4 chains, 0 divergences:   0%|            | 0/6000 [00:00<?, ?draws/s]
Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "C:\ProgramData\Anaconda3\lib\multiprocessing\spawn.py", line 105, in spawn_main
    exitcode = _main(fd)
  File "C:\ProgramData\Anaconda3\lib\multiprocessing\spawn.py", line 115, in _main
    self = reduction.pickle.load(from_parent)
  File "C:\ProgramData\Anaconda3\lib\site-packages\theano\compile\function_module.py", line 1082, in _constructor_Function
    f = maker.create(input_storage, trustme=True)
  File "C:\ProgramData\Anaconda3\lib\site-packages\theano\compile\function_module.py", line 1715, in create
    input_storage=input_storage_lists, storage_map=storage_map)
  File "C:\ProgramData\Anaconda3\lib\site-packages\theano\gof\link.py", line 699, in make_thunk
    storage_map=storage_map)[:3]
  File "C:\ProgramData\Anaconda3\lib\site-packages\theano\gof\vm.py", line 1091, in make_all
    impl=impl))
  File "C:\ProgramData\Anaconda3\lib\site-packages\theano\gof\op.py", line 955, in make_thunk
    no_recycling)
  File "C:\ProgramData\Anaconda3\lib\site-packages\theano\gof\op.py", line 858, in make_c_thunk
    output_storage=node_output_storage)
  File "C:\ProgramData\Anaconda3\lib\site-packages\theano\gof\cc.py", line 1217, in make_thunk
    keep_lock=keep_lock)
  File "C:\ProgramData\Anaconda3\lib\site-packages\theano\gof\cc.py", line 1157, in __compile__
    keep_lock=keep_lock)
  File "C:\ProgramData\Anaconda3\lib\site-packages\theano\gof\cc.py", line 1623, in cthunk_factory
    module = get_module_cache().module_from_key(
  File "C:\ProgramData\Anaconda3\lib\site-packages\theano\gof\cc.py", line 48, in get_module_cache
    return cmodule.get_module_cache(config.compiledir, init_args=init_args)
  File "C:\ProgramData\Anaconda3\lib\site-packages\theano\gof\cmodule.py", line 1587, in get_module_cache
    _module_cache = ModuleCache(dirname, **init_args)
  File "C:\ProgramData\Anaconda3\lib\site-packages\theano\gof\cmodule.py", line 703, in __init__
    self.refresh()
  File "C:\ProgramData\Anaconda3\lib\site-packages\theano\gof\cmodule.py", line 826, in refresh
    key_data = pickle.load(f)
  File "C:\ProgramData\Anaconda3\lib\site-packages\theano\tensor\elemwise.py", line 412, in __setstate__
    super(Elemwise, self).__setstate__(d)
  File "C:\ProgramData\Anaconda3\lib\site-packages\theano\gof\op.py", line 1160, in __setstate__
    self.__dict__.update(d)
KeyboardInterrupt

---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
C:\ProgramData\Anaconda3\lib\site-packages\pymc3\sampling.py in _mp_sample(draws, tune, step, chains, cores, chain, random_seed, start, progressbar, trace, model, **kwargs)
   1058             with sampler:
-> 1059                 for draw in sampler:
   1060                     trace = traces[draw.chain - chain]

C:\ProgramData\Anaconda3\lib\site-packages\pymc3\parallel_sampling.py in __iter__(self)
    393         while self._active:
--> 394             draw = ProcessAdapter.recv_draw(self._active)
    395             proc, is_last, draw, tuning, stats, warns = draw

C:\ProgramData\Anaconda3\lib\site-packages\pymc3\parallel_sampling.py in recv_draw(processes, timeout)
    283         pipes = [proc._msg_pipe for proc in processes]
--> 284         ready = multiprocessing.connection.wait(pipes)
    285         if not ready:

C:\ProgramData\Anaconda3\lib\multiprocessing\connection.py in wait(object_list, timeout)
    858
--> 859             ready_handles = _exhaustive_wait(waithandle_to_obj.keys(), timeout)
    860         finally:

C:\ProgramData\Anaconda3\lib\multiprocessing\connection.py in _exhaustive_wait(handles, timeout)
    790         while L:
--> 791             res = _winapi.WaitForMultipleObjects(L, False, timeout)
    792             if res == WAIT_TIMEOUT:

KeyboardInterrupt:

During handling of the above exception, another exception occurred:

ValueError                                Traceback (most recent call last)
<ipython-input-14-1a746dc8612f> in <module>
      2     gamma = pm.Gamma("gamma", alpha=1., beta=1., shape=K)
      3     theta = pm.Dirichlet("theta", a=gamma, shape=K, observed=x_)
----> 4     trace = pm.sample(draws=1000)
      5

C:\ProgramData\Anaconda3\lib\site-packages\pymc3\sampling.py in sample(draws, step, init, n_init, start, trace, chain_idx, chains, cores, tune, progressbar, model, random_seed, discard_tuned_samples, compute_convergence_checks, **kwargs)
    467         _print_step_hierarchy(step)
    468         try:
--> 469             trace = _mp_sample(**sample_args)
    470         except pickle.PickleError:
    471             _log.warning("Could not pickle model, sampling singlethreaded.")

C:\ProgramData\Anaconda3\lib\site-packages\pymc3\sampling.py in _mp_sample(draws, tune, step, chains, cores, chain, random_seed, start, progressbar, trace, model, **kwargs)
   1078         return MultiTrace(traces)
   1079     except KeyboardInterrupt:
-> 1080         traces, length = _choose_chains(traces, tune)
   1081         return MultiTrace(traces)[:length]
   1082     finally:

C:\ProgramData\Anaconda3\lib\site-packages\pymc3\sampling.py in _choose_chains(traces, tune)
   1094     lengths = [max(0, len(trace) - tune) for trace in traces]
   1095     if not sum(lengths):
-> 1096         raise ValueError("Not enough samples to build a trace.")
   1097
   1098     idxs = np.argsort(lengths)[::-1]

ValueError: Not enough samples to build a trace.

The array I passed to observed has shape (12625, 12), so I would assume the sample size is not the problem. To rule out that PyMC3 interprets x as a single sample of shape (12625, 12), I passed shape=x.shape when initializing theta, but got the same error.

Ooh, seems like you’re on Windows, so it could be a nasty multiprocessing issue…
Try pm.sample(cores=1), which samples the chains sequentially in a single process.

I tried that, and also tried Ubuntu, but now I get SamplingError: Bad initial energy. Sorry for the back and forth, but I have no idea what could cause this.

Update: Looking at this post, I got it to run on Windows WSL (Ubuntu) with

with pm.Model() as model:
    gamma = pm.Gamma("gamma", alpha=1., beta=1., shape=K)
    theta = pm.Dirichlet("theta", a=gamma, shape=K, observed=x)

    trace = pm.sample(draws=1000, cores=2, init="adapt_diag", tune=1000)

Everything seems to work fine for the first 2000 iterations (the first half), but in the second half I get only divergences, and the output reads "The chain contains only diverging samples. The model is probably misspecified."

Yeah, Bad initial energy usually comes from overly vague priors, a badly initialized NUTS sampler, and/or missing values somewhere. Good job on sorting this out 👌
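One concrete thing worth checking, offered as a hypothesis rather than a diagnosis: the first row of the data posted at the top contains an exact 0.0 (its sixth entry). A Dirichlet density at a point with a zero component is finite only when the matching concentration is exactly 1 (the density is 0 for a_j > 1 and infinite for a_j < 1), so as soon as a sampler moves a_j away from 1, the log-likelihood of such a row stops being finite, which would be consistent with Bad initial energy and all-divergent chains. The sketch below evaluates the standard Dirichlet log-density in plain numpy (it does not call PyMC3, whose internal handling of zeros may differ), and the helpers dirichlet_logpdf and count_zero_rows are hypothetical. Whether the full 12625-row array contains many such rows is an assumption to verify.

```python
import math
import numpy as np

def dirichlet_logpdf(x, a):
    """Log-density of Dirichlet(a) at one point x on the simplex (x_j >= 0)."""
    x = np.asarray(x, dtype=float)
    a = np.asarray(a, dtype=float)
    log_norm = math.lgamma(a.sum()) - sum(math.lgamma(ai) for ai in a)
    terms = np.zeros_like(x)
    pos = x > 0
    terms[pos] = (a[pos] - 1.0) * np.log(x[pos])
    # At x_j == 0, x_j ** (a_j - 1) is 1 if a_j == 1, 0 if a_j > 1, infinite if a_j < 1.
    zero = ~pos
    terms[zero] = np.where(a[zero] == 1.0, 0.0,
                           np.where(a[zero] > 1.0, -np.inf, np.inf))
    return log_norm + terms.sum()

def count_zero_rows(x):
    """Number of rows of an (n, K) array that contain at least one exact zero."""
    return int(np.sum(np.any(np.asarray(x) == 0, axis=1)))

# First row of the posted data (note the exact 0.0 in the sixth position).
row0 = np.array([0.00560805, 0.16788784, 0.06668699, 0.02216397, 0.21703139,
                 0.0, 0.04982627, 0.00448644, 0.11845169, 0.21861628,
                 0.00616885, 0.12307223])
K = row0.size

print(dirichlet_logpdf(row0, np.ones(K)))       # finite: equals math.lgamma(12)
print(dirichlet_logpdf(row0, np.full(K, 2.0)))  # -inf
print(dirichlet_logpdf(row0, np.full(K, 0.5)))  # inf
print(count_zero_rows(row0[None, :]))           # 1; apply to the full x as well
```

If zeros do turn out to be common, a crude and commonly used workaround is to replace them with a small positive value and renormalize each row, though that choice itself deserves a sensitivity check.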

I'd probably do some prior and posterior predictive checks to try to understand why your model is misspecified. Here's an example of how to do it.
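As a minimal illustration of the prior predictive idea, here is a plain-numpy stand-in for PyMC3's pm.sample_prior_predictive, assuming the Gamma(1, 1) prior from the working model above. It simulates whole datasets by drawing concentrations and then Dirichlet rows, and summarizes each dataset with the share of entries below 0.01. Both that summary statistic and the helper names are arbitrary, hypothetical choices; the idea is to compare the simulated spread with the same statistic computed on the observed x.

```python
import numpy as np

def prior_predictive(K, n_rows, n_sims, alpha=1.0, beta=1.0, seed=0):
    """Simulate datasets: Gamma(alpha, beta) concentrations -> Dirichlet rows.

    Mirrors gamma ~ Gamma(alpha, beta, shape=K), theta ~ Dirichlet(gamma).
    numpy's gamma takes scale = 1 / beta, since PyMC3's beta is a rate.
    Returns an array of shape (n_sims, n_rows, K).
    """
    rng = np.random.default_rng(seed)
    sims = np.empty((n_sims, n_rows, K))
    for s in range(n_sims):
        conc = rng.gamma(shape=alpha, scale=1.0 / beta, size=K)
        sims[s] = rng.dirichlet(conc, size=n_rows)
    return sims

def share_small(x, threshold=0.01):
    """Fraction of entries below threshold: a crude summary of sparsity."""
    return float(np.mean(np.asarray(x) < threshold))

sims = prior_predictive(K=12, n_rows=3, n_sims=1000)
sim_stats = np.array([share_small(d) for d in sims])
# Compare with share_small(x) for the observed array x; if the observed value
# sits in the far tail of sim_stats, the prior (or the model) is a poor match.
print(np.percentile(sim_stats, [5, 50, 95]))
```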

Thanks a lot for your help!