Inferring Dirichlet concentration parameter

I am trying to infer the most likely concentration parameter for samples from a Dirichlet distribution but am struggling to set this up in PyMC3.

My data is

x = np.array([[0.00560805, 0.16788784, 0.06668699, 0.02216397, 0.21703139,
        0.        , 0.04982627, 0.00448644, 0.11845169, 0.21861628,
        0.00616885, 0.12307223],
       [0.19244308, 0.02297949, 0.09517768, 0.03622034, 0.07414341,
        0.19984624, 0.01186158, 0.11298382, 0.043279  , 0.09359883,
        0.08446663, 0.0329999 ],
       [0.22363757, 0.00368887, 0.08686129, 0.0055333 , 0.11712153,
        0.19620162, 0.0165999 , 0.09579527, 0.0055333 , 0.14657483,
        0.08769705, 0.01475547]])

and my model setup is

n = x.shape[0]
K = x.shape[1]

with pm.Model() as model:
    alpha = pm.Multinomial("alpha", n=n, p=np.ones(12), shape=K)
    theta = pm.Dirichlet('theta', a=alpha, shape=K, observed=x)

    trace = pm.sample(draws=1000) 

but the traceplot I get does not make any sense to me.

I am quite new in using PyMC3 and I could not figure out what I am doing wrong here.

Hi Fabian,

  • The concentration parameters of the Dirichlet must be real numbers > 0, which is not what the Multinomial produces: it is a generalization of the Binomial, so it gives you vectors of non-negative integers, namely the number of trials that landed in each category. The concentration of a Dirichlet is therefore usually better modeled with a continuous, strictly positive distribution such as a Gamma.
  • You don’t need to pass shape to your likelihood: as you give it observed data, it knows what shape it needs.
  • For future reference, your Multinomial is incorrectly parametrized: you’re passing a probability of one to each of the categories. It doesn’t error out because PyMC3 normalizes the p vector under the hood, so it’s interpreted as a uniform prior on the categories, but it’s definitely good to know.
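To make the difference in support concrete, here is a minimal sketch that uses only Python's standard library instead of PyMC3 (an assumption made so the snippet runs anywhere). It normalizes a vector of ones the way the last bullet describes, draws one Multinomial-style count vector, and draws one Gamma value per category. The Gamma hyperparameters (5 and 1) are arbitrary.

```python
import random
from collections import Counter

rng = random.Random(0)
K = 12  # number of categories, as in the question
n = 3   # number of trials, as in the question (n = x.shape[0])

# A vector of ones is not a probability vector; normalizing gives 1/K each.
raw_p = [1.0] * K
p = [w / sum(raw_p) for w in raw_p]

# One Multinomial(n, p) draw: how many of the n trials fell in each category.
counts = Counter(rng.choices(range(K), weights=p, k=n))
multinomial_draw = [counts[k] for k in range(K)]

# One Gamma draw per category (shape 5, scale 1): strictly positive reals.
gamma_draw = [rng.gammavariate(5.0, 1.0) for _ in range(K)]

print(multinomial_draw)  # integers, mostly zeros
print(gamma_draw)        # positive floats
```

With only 3 trials spread over 12 categories, at least 9 entries of the count vector are zero, which is not a valid Dirichlet concentration.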

Hope this helps 🖖


Hi Alex, thank you for your helpful answer!

I exchanged the Multinomial with a Gamma distribution and chose arbitrary hyperparameters for it:

n = x.shape[0]
K = x.shape[1]

with pm.Model() as model:
    alpha = 5.
    beta = 1.
    
    gamma = pm.Gamma("gamma", alpha=alpha, beta=beta)
    theta = pm.Dirichlet('theta', a=gamma, observed=x) 

    trace = pm.sample(draws=1000) 

but I get a TypeError:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-8-eb78b2b21632> in <module>
      7 
      8     gamma = pm.Gamma("gamma", alpha=alpha, beta=beta)
----> 9     theta = pm.Dirichlet('theta', a=gamma, observed=x)
     10 
     11     trace = pm.sample(draws=1000)

C:\ProgramData\Anaconda3\lib\site-packages\pymc3\distributions\distribution.py in __new__(cls, name, *args, **kwargs)
     44                 raise TypeError("observed needs to be data but got: {}".format(type(data)))
     45             total_size = kwargs.pop('total_size', None)
---> 46             dist = cls.dist(*args, **kwargs)
     47             return model.Var(name, dist, data, total_size)
     48         else:

C:\ProgramData\Anaconda3\lib\site-packages\pymc3\distributions\distribution.py in dist(cls, *args, **kwargs)
     55     def dist(cls, *args, **kwargs):
     56         dist = object.__new__(cls)
---> 57         dist.__init__(*args, **kwargs)
     58         return dist
     59 

C:\ProgramData\Anaconda3\lib\site-packages\pymc3\distributions\multivariate.py in __init__(self, a, transform, *args, **kwargs)
    478 
    479         kwargs.setdefault("shape", shape)
--> 480         super().__init__(transform=transform, *args, **kwargs)
    481 
    482         self.size_prefix = tuple(self.shape[:-1])

C:\ProgramData\Anaconda3\lib\site-packages\pymc3\distributions\distribution.py in __init__(self, shape, dtype, defaults, *args, **kwargs)
    191         if dtype is None:
    192             dtype = theano.config.floatX
--> 193         super().__init__(shape, dtype, defaults=defaults, *args, **kwargs)
    194 
    195 

C:\ProgramData\Anaconda3\lib\site-packages\pymc3\distributions\distribution.py in __init__(self, shape, dtype, testval, defaults, transform, broadcastable)
     61                  transform=None, broadcastable=None):
     62         self.shape = np.atleast_1d(shape)
---> 63         if False in (np.floor(self.shape) == self.shape):
     64             raise TypeError("Expected int elements in shape")
     65         self.dtype = dtype

TypeError: must be real number, not TensorVariable

I could get rid of it by passing shape=K to both the Gamma and the Dirichlet variables, but when I want to run the code, it crashes (I am working in a Jupyter notebook and I lose the server connection).
Apparently, the shape parameter is needed or I am doing something else wrong.

Yeah, you definitely need to pass the shape arg to the Gamma, otherwise it can’t infer the length of the vector you need, i.e., one element per category in your Dirichlet; shape=K here, I think.

And yeah maybe you’ll need to also specify the shape to the Dirichlet, even though it’s a likelihood (my bad here, I forgot the Dirichlet is a little special), but try without first.
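A rough way to picture the shapes, sketched in plain Python rather than PyMC3: one concentration vector of length K is shared by all n rows of the data, and each row contributes one term to the log-likelihood. The helper names dirichlet_logpdf and total_loglik and the toy data are made up for illustration.

```python
import math

def dirichlet_logpdf(x, a):
    """Log density of one composition x (length K) under concentration a (length K)."""
    log_norm = math.lgamma(sum(a)) - sum(math.lgamma(a_k) for a_k in a)
    return log_norm + sum((a_k - 1.0) * math.log(x_k) for a_k, x_k in zip(a, x))

def total_loglik(data, a):
    """Sum the log density over the n rows; the same length-K vector a is used for every row."""
    if any(len(row) != len(a) for row in data):
        raise ValueError("every row must have as many entries as the concentration vector")
    return sum(dirichlet_logpdf(row, a) for row in data)

# Toy data, made up for illustration: n = 2 samples, K = 3 categories.
data = [[0.2, 0.3, 0.5],
        [0.1, 0.6, 0.3]]
a = [1.0, 1.0, 1.0]  # flat Dirichlet: one concentration value per category

print(total_loglik(data, a))
```

So the data has shape (n, K), while the concentration only needs shape K.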

I am actually not quite sure how to assign the shapes.
The shape of my data is an (n,K)-array where n is the number of samples and K is the dimension of each sample. So, the Gamma distribution should have shape=K and the Dirichlet as well, right?

But unfortunately, my notebook crashes. Could this be a Jupyter problem?

The Dirichlet returns a vector of probabilities of the categories: [p_1, p_2, ..., p_k], with \sum_{i=1}^{k} p_i = 1. This is a generalization of the Beta distribution to more than two categories.
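As an illustration of that sum-to-one property, a Dirichlet draw can be built by normalizing independent Gamma draws, which is a standard construction. The sketch below uses only the standard library; the function name sample_dirichlet and the concentration value 5 are assumptions chosen for the example.

```python
import random

def sample_dirichlet(a, rng):
    """Draw one Dirichlet(a) sample by normalizing independent Gamma(a_k, 1) draws."""
    g = [rng.gammavariate(a_k, 1.0) for a_k in a]
    total = sum(g)
    return [g_k / total for g_k in g]

rng = random.Random(0)
p = sample_dirichlet([5.0] * 12, rng)  # K = 12 categories, as in the question

print(len(p), sum(p))  # 12 components that sum to 1 up to floating-point error
```

With two categories, the first component of such a draw follows a Beta distribution.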

What do you mean by the notebook crashes? What’s the error you’re getting?

Hi, yes, that’s why I use the Dirichlet to model my data: each sample is a distribution over K items, but I do not know much about the underlying generative process. So I thought that using a Dirichlet would be appropriate, and I want to look at how the concentration parameters change across subsets of the whole sample.

Initializing the model seems to work: I don’t get any error and can call both the gamma and theta variables. But when I try to run trace = pm.sample(draws=1000) (within the context manager, of course), my notebook server crashes with the error message:

Server Connection Error
A connection to the Jupyter server could not be established. JupyterLab will continue trying to reconnect. Check your network connection or Jupyter server configuration.

I am sure that calling sample is the cause for this but don’t know what I’d have to change.

When I reproduced the code in the ipython CLI, I got the following error message:

Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [gamma]
Sampling 4 chains, 0 divergences:   0%|            | 0/6000 [00:00<?, ?draws/s]
Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "C:\ProgramData\Anaconda3\lib\multiprocessing\spawn.py", line 105, in spawn_main
    exitcode = _main(fd)
  File "C:\ProgramData\Anaconda3\lib\multiprocessing\spawn.py", line 115, in _main
    self = reduction.pickle.load(from_parent)
  File "C:\ProgramData\Anaconda3\lib\site-packages\theano\compile\function_module.py", line 1082, in _constructor_Function
    f = maker.create(input_storage, trustme=True)
  File "C:\ProgramData\Anaconda3\lib\site-packages\theano\compile\function_module.py", line 1715, in create
    input_storage=input_storage_lists, storage_map=storage_map)
  File "C:\ProgramData\Anaconda3\lib\site-packages\theano\gof\link.py", line 699, in make_thunk
    storage_map=storage_map)[:3]
  File "C:\ProgramData\Anaconda3\lib\site-packages\theano\gof\vm.py", line 1091, in make_all
    impl=impl))
  File "C:\ProgramData\Anaconda3\lib\site-packages\theano\gof\op.py", line 955, in make_thunk
    no_recycling)
  File "C:\ProgramData\Anaconda3\lib\site-packages\theano\gof\op.py", line 858, in make_c_thunk
    output_storage=node_output_storage)
  File "C:\ProgramData\Anaconda3\lib\site-packages\theano\gof\cc.py", line 1217, in make_thunk
    keep_lock=keep_lock)
  File "C:\ProgramData\Anaconda3\lib\site-packages\theano\gof\cc.py", line 1157, in __compile__
    keep_lock=keep_lock)
  File "C:\ProgramData\Anaconda3\lib\site-packages\theano\gof\cc.py", line 1623, in cthunk_factory
    module = get_module_cache().module_from_key(
  File "C:\ProgramData\Anaconda3\lib\site-packages\theano\gof\cc.py", line 48, in get_module_cache
    return cmodule.get_module_cache(config.compiledir, init_args=init_args)
  File "C:\ProgramData\Anaconda3\lib\site-packages\theano\gof\cmodule.py", line 1587, in get_module_cache
    _module_cache = ModuleCache(dirname, **init_args)
  File "C:\ProgramData\Anaconda3\lib\site-packages\theano\gof\cmodule.py", line 703, in __init__
    self.refresh()
  File "C:\ProgramData\Anaconda3\lib\site-packages\theano\gof\cmodule.py", line 826, in refresh
    key_data = pickle.load(f)
  File "C:\ProgramData\Anaconda3\lib\site-packages\theano\tensor\elemwise.py", line 412, in __setstate__
    super(Elemwise, self).__setstate__(d)
  File "C:\ProgramData\Anaconda3\lib\site-packages\theano\gof\op.py", line 1160, in __setstate__
    self.__dict__.update(d)
KeyboardInterrupt

---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
C:\ProgramData\Anaconda3\lib\site-packages\pymc3\sampling.py in _mp_sample(draws, tune, step, chains, cores, chain, random_seed, start, progressbar, trace, model, **kwargs)
   1058             with sampler:
-> 1059                 for draw in sampler:
   1060                     trace = traces[draw.chain - chain]

C:\ProgramData\Anaconda3\lib\site-packages\pymc3\parallel_sampling.py in __iter__(self)
    393         while self._active:
--> 394             draw = ProcessAdapter.recv_draw(self._active)
    395             proc, is_last, draw, tuning, stats, warns = draw

C:\ProgramData\Anaconda3\lib\site-packages\pymc3\parallel_sampling.py in recv_draw(processes, timeout)
    283         pipes = [proc._msg_pipe for proc in processes]
--> 284         ready = multiprocessing.connection.wait(pipes)
    285         if not ready:

C:\ProgramData\Anaconda3\lib\multiprocessing\connection.py in wait(object_list, timeout)
    858
--> 859             ready_handles = _exhaustive_wait(waithandle_to_obj.keys(), timeout)
    860         finally:

C:\ProgramData\Anaconda3\lib\multiprocessing\connection.py in _exhaustive_wait(handles, timeout)
    790         while L:
--> 791             res = _winapi.WaitForMultipleObjects(L, False, timeout)
    792             if res == WAIT_TIMEOUT:

KeyboardInterrupt:

During handling of the above exception, another exception occurred:

ValueError                                Traceback (most recent call last)
<ipython-input-14-1a746dc8612f> in <module>
      2     gamma = pm.Gamma("gamma", alpha=1., beta=1., shape=K)
      3     theta = pm.Dirichlet("theta", a=gamma, shape=K, observed=x_)
----> 4     trace = pm.sample(draws=1000)
      5

C:\ProgramData\Anaconda3\lib\site-packages\pymc3\sampling.py in sample(draws, step, init, n_init, start, trace, chain_idx, chains, cores, tune, progressbar, model, random_seed, discard_tuned_samples, compute_convergence_checks, **kwargs)
    467         _print_step_hierarchy(step)
    468         try:
--> 469             trace = _mp_sample(**sample_args)
    470         except pickle.PickleError:
    471             _log.warning("Could not pickle model, sampling singlethreaded.")

C:\ProgramData\Anaconda3\lib\site-packages\pymc3\sampling.py in _mp_sample(draws, tune, step, chains, cores, chain, random_seed, start, progressbar, trace, model, **kwargs)
   1078         return MultiTrace(traces)
   1079     except KeyboardInterrupt:
-> 1080         traces, length = _choose_chains(traces, tune)
   1081         return MultiTrace(traces)[:length]
   1082     finally:

C:\ProgramData\Anaconda3\lib\site-packages\pymc3\sampling.py in _choose_chains(traces, tune)
   1094     lengths = [max(0, len(trace) - tune) for trace in traces]
   1095     if not sum(lengths):
-> 1096         raise ValueError("Not enough samples to build a trace.")
   1097
   1098     idxs = np.argsort(lengths)[::-1]

ValueError: Not enough samples to build a trace.

The sample that I passed to observed has shape=(12625,12), and I would assume that the sample size is not really the problem. To test that PyMC3 does not interpret x as a single sample of shape (12625,12), I passed shape=x.shape to the initialization of theta but got the same error.

Ooh, seems like you’re on Windows, so it could be a nasty multiprocessing issue…
Try doing pm.sample(cores=1)

I tried that and also using Ubuntu but now I get SamplingError: Bad initial energy. Sorry for the back and forth, but I have no idea what could cause this.

Update: Looking at this post, I got it to run on Windows WSL (Ubuntu) with

with pm.Model() as model:
    gamma = pm.Gamma("gamma", alpha=1., beta=1., shape=K)
    theta = pm.Dirichlet("theta", a=gamma, shape=K, observed=x)

    trace = pm.sample(draws=1000, cores=2, init="adapt_diag", tune=1000)

Everything seems to work fine for the first 2000 iterations (first half), but for the second half I get only divergences, and the output reads: "The chain contains only diverging samples. The model is probably misspecified."


Yeah, Bad initial energy usually comes from priors that are too vague, a messed-up NUTS initialization, and/or missing values somewhere. Good job on sorting this out 👌

I’d probably do some prior and posterior predictive checks to try and understand why your model is misspecified. Here’s an example of how to do it.
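One data check that may be worth running alongside the predictive checks: the first row of the data posted at the top of the thread contains an exact 0, and the Dirichlet density is zero or infinite on the boundary of the simplex whenever the corresponding concentration differs from 1, so the log-likelihood is not finite there. Whether this is what triggers the errors in this thread is not confirmed. The sketch below is plain Python; the helper names and the eps value are hypothetical, and nudging values off the boundary changes the data slightly, so treat it as a diagnostic rather than a fix.

```python
# First row of the data posted at the top of the thread.
row = [0.00560805, 0.16788784, 0.06668699, 0.02216397, 0.21703139,
       0.0, 0.04982627, 0.00448644, 0.11845169, 0.21861628,
       0.00616885, 0.12307223]

def find_boundary_values(row):
    """Return the indices of components that sit on the simplex boundary (<= 0 or >= 1)."""
    return [i for i, v in enumerate(row) if v <= 0.0 or v >= 1.0]

def nudge_off_boundary(row, eps=1e-6):
    """Hypothetical helper: raise each component to at least eps, then renormalize to sum to 1."""
    clipped = [max(v, eps) for v in row]
    total = sum(clipped)
    return [v / total for v in clipped]

print(find_boundary_values(row))  # -> [5]
fixed = nudge_off_boundary(row)
print(min(fixed), sum(fixed))
```

Running find_boundary_values over every row of the full (12625, 12) array would show how many samples are affected.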

Thanks a lot for your help!
