Dirichlet distribution in 4.0.0b4

It seems the Dirichlet distribution is not working in v4.0.0b3 yet, although I cannot find anything about this in the release notes. Since mixture distributions have been implemented and there seem to be working examples in the new docs, I assumed this was already working. Is this a known WIP or a problem on my end? Here is an example:

import numpy as np
import pymc as pm

with pm.Model() as model:
    a = pm.Dirichlet('a', np.ones(3))
    pm.sample()
---------------------------------------------------------------------------
NotImplementedError                       Traceback (most recent call last)
File ~/.conda/envs/pymc_v1/lib/python3.9/site-packages/aesara/compile/function/types.py:964, in Function.__call__(self, *args, **kwargs)
    962 try:
    963     outputs = (
--> 964         self.fn()
    965         if output_subset is None
    966         else self.fn(output_subset=output_subset)
    967     )
    968 except Exception:

File ~/.conda/envs/pymc_v1/lib/python3.9/site-packages/aesara/graph/op.py:522, in Op.make_py_thunk.<locals>.rval(p, i, o, n, params)
    518 @is_thunk_type
    519 def rval(
    520     p=p, i=node_input_storage, o=node_output_storage, n=node, params=None
    521 ):
--> 522     r = p(n, [x[0] for x in i], o)
    523     for o in node.outputs:

File ~/.conda/envs/pymc_v1/lib/python3.9/site-packages/aeppl/transforms.py:48, in TransformedVariable.perform(self, node, inputs, outputs)
     47 def perform(self, node, inputs, outputs):
---> 48     raise NotImplementedError(
     49         "These `Op`s should be removed from graphs used for computation."
     50     )

NotImplementedError: These `Op`s should be removed from graphs used for computation.

During handling of the above exception, another exception occurred:

NotImplementedError                       Traceback (most recent call last)
Input In [9], in <cell line: 4>()
      4 with pm.Model() as model:
      5     a = pm.Dirichlet('a', np.ones(3))
----> 6     pm.sample()

File ~/.conda/envs/pymc_v1/lib/python3.9/site-packages/pymc/sampling.py:487, in sample(draws, step, init, n_init, initvals, trace, chain_idx, chains, cores, tune, progressbar, model, random_seed, discard_tuned_samples, compute_convergence_checks, callback, jitter_max_retries, return_inferencedata, idata_kwargs, mp_ctx, **kwargs)
    485 # One final check that shapes and logps at the starting points are okay.
    486 for ip in initial_points:
--> 487     model.check_start_vals(ip)
    488     _check_start_shape(model, ip)
    490 sample_args = {
    491     "draws": draws,
    492     "step": step,
   (...)
    503     "discard_tuned_samples": discard_tuned_samples,
    504 }

File ~/.conda/envs/pymc_v1/lib/python3.9/site-packages/pymc/model.py:1680, in Model.check_start_vals(self, start)
   1674     valid_keys = ", ".join(self.named_vars.keys())
   1675     raise KeyError(
   1676         "Some start parameters do not appear in the model!\n"
   1677         f"Valid keys are: {valid_keys}, but {extra_keys} was supplied"
   1678     )
-> 1680 initial_eval = self.point_logps(point=elem)
   1682 if not all(np.isfinite(v) for v in initial_eval.values()):
   1683     raise SamplingError(
   1684         "Initial evaluation of model at starting point failed!\n"
   1685         f"Starting values:\n{elem}\n\n"
   1686         f"Initial evaluation results:\n{initial_eval}"
   1687     )

File ~/.conda/envs/pymc_v1/lib/python3.9/site-packages/pymc/model.py:1721, in Model.point_logps(self, point, round_vals)
   1715 factors = self.basic_RVs + self.potentials
   1716 factor_logps_fn = [at.sum(factor) for factor in self.logpt(factors, sum=False)]
   1717 return {
   1718     factor.name: np.round(np.asarray(factor_logp), round_vals)
   1719     for factor, factor_logp in zip(
   1720         factors,
-> 1721         self.compile_fn(factor_logps_fn)(point),
   1722     )
   1723 }

File ~/.conda/envs/pymc_v1/lib/python3.9/site-packages/pymc/model.py:1820, in PointFunc.__call__(self, state)
   1819 def __call__(self, state):
-> 1820     return self.f(**state)

File ~/.conda/envs/pymc_v1/lib/python3.9/site-packages/aesara/compile/function/types.py:977, in Function.__call__(self, *args, **kwargs)
    975     if hasattr(self.fn, "thunks"):
    976         thunk = self.fn.thunks[self.fn.position_of_error]
--> 977     raise_with_op(
    978         self.maker.fgraph,
    979         node=self.fn.nodes[self.fn.position_of_error],
    980         thunk=thunk,
    981         storage_map=getattr(self.fn, "storage_map", None),
    982     )
    983 else:
    984     # old-style linkers raise their own exceptions
    985     raise

File ~/.conda/envs/pymc_v1/lib/python3.9/site-packages/aesara/link/utils.py:538, in raise_with_op(fgraph, node, thunk, exc_info, storage_map)
    533     warnings.warn(
    534         f"{exc_type} error does not allow us to add an extra error message"
    535     )
    536     # Some exception need extra parameter in inputs. So forget the
    537     # extra long error message in that case.
--> 538 raise exc_value.with_traceback(exc_trace)

File ~/.conda/envs/pymc_v1/lib/python3.9/site-packages/aesara/compile/function/types.py:964, in Function.__call__(self, *args, **kwargs)
    961 t0_fn = time.time()
    962 try:
    963     outputs = (
--> 964         self.fn()
    965         if output_subset is None
    966         else self.fn(output_subset=output_subset)
    967     )
    968 except Exception:
    969     restore_defaults()

File ~/.conda/envs/pymc_v1/lib/python3.9/site-packages/aesara/graph/op.py:522, in Op.make_py_thunk.<locals>.rval(p, i, o, n, params)
    518 @is_thunk_type
    519 def rval(
    520     p=p, i=node_input_storage, o=node_output_storage, n=node, params=None
    521 ):
--> 522     r = p(n, [x[0] for x in i], o)
    523     for o in node.outputs:
    524         compute_map[o][0] = True

File ~/.conda/envs/pymc_v1/lib/python3.9/site-packages/aeppl/transforms.py:48, in TransformedVariable.perform(self, node, inputs, outputs)
     47 def perform(self, node, inputs, outputs):
---> 48     raise NotImplementedError(
     49         "These `Op`s should be removed from graphs used for computation."
     50     )

NotImplementedError: These `Op`s should be removed from graphs used for computation.
Apply node that caused the error: TransformedVariable(Softmax{axis=0}.0, a_simplex__)
Toposort index: 20
Inputs types: [TensorType(float64, (None,)), TensorType(float64, (None,))]
Inputs shapes: [(3,), (2,)]
Inputs strides: [(8,), (8,)]
Inputs values: [array([0.33333333, 0.33333333, 0.33333333]), array([0., 0.])]
Outputs clients: [[Elemwise{eq,no_inplace}(a_simplex___simplex, TensorConstant{(1,) of 0}), Elemwise{gt,no_inplace}(a_simplex___simplex, TensorConstant{(1,) of 1}), Elemwise{lt,no_inplace}(a_simplex___simplex, TensorConstant{(1,) of 0})]]

Backtrace when the node is created (use Aesara flag traceback__limit=N to make it longer):
  File "/home/dotto/.conda/envs/pymc_v1/lib/python3.9/site-packages/aeppl/transforms.py", line 203, in apply
    return self.default_transform_opt.optimize(fgraph)
  File "/home/dotto/.conda/envs/pymc_v1/lib/python3.9/site-packages/aesara/graph/opt.py", line 103, in optimize
    ret = self.apply(fgraph, *args, **kwargs)
  File "/home/dotto/.conda/envs/pymc_v1/lib/python3.9/site-packages/aesara/graph/opt.py", line 1960, in apply
    nb += self.process_node(fgraph, node)
  File "/home/dotto/.conda/envs/pymc_v1/lib/python3.9/site-packages/aesara/graph/opt.py", line 1850, in process_node
    replacements = lopt.transform(fgraph, node)
  File "/home/dotto/.conda/envs/pymc_v1/lib/python3.9/site-packages/aesara/graph/opt.py", line 1055, in transform
    return self.fn(fgraph, node)
  File "/home/dotto/.conda/envs/pymc_v1/lib/python3.9/site-packages/aeppl/transforms.py", line 148, in transform_values
    new_value_var = transformed_variable(
  File "/home/dotto/.conda/envs/pymc_v1/lib/python3.9/site-packages/aesara/graph/op.py", line 294, in __call__
    node = self.make_node(*inputs, **kwargs)
  File "/home/dotto/.conda/envs/pymc_v1/lib/python3.9/site-packages/aeppl/transforms.py", line 45, in make_node
    return Apply(self, [tran_value, value], [tran_value.type()])

HINT: Use the Aesara flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.

I installed everything with Conda and my versions are:

  • PyMC 4.0.0b4
  • Aesara 2.5.1
  • aePPL 0.0.27
  • Ubuntu 18.04.5 LTS

The error above seems to indicate that Softmax is applied to the transformed RV of the Dirichlet distribution. However, the transform currently used is `aeppl.transforms.Simplex`, which does not explicitly use the Softmax function.

1 Like

Two things. First, I would suggest submitting an issue about this. Second, I would strongly suggest upgrading to v4b4 due to a memory leak in v4b3.

Thanks @cluhmann. Sorry for the wrong title; I did in fact use 4.0.0b4 above. If this error is reproducible and not just an issue on my end, I will open an issue on GitHub.

Quite a surprising error. The Dirichlet was refactored a long time ago and it's used in several tests. I haven't seen any issues with it.

Hi @ricardoV94! I was surprised by this problem as well and am still not entirely sure there is nothing wrong on my end. Can you reproduce the error?

I replied on the GitHub issue. I couldn’t reproduce it on Colab. Can you try setting up a fresh environment?

I tried it in two new Conda environments, using pip to install the main branch from the pymc GitHub repository in one and pymc==4.0.0b4 in the other, just like in the Colab notebook. I also used Python 3.10.2 this time. However, the same error still comes up. I will try it on a different OS now.

I was not able to reproduce this on macOS, nor with a different user account on the same Ubuntu machine. I am now very confused about where this error comes from. Using my own account on the Ubuntu machine, I can reproduce it in multiple freshly set-up environments, even after making sure PYTHONPATH is not set, cleaning the PATH environment variable, and running the example from different working directories. What else could be causing this?

I set aesara.config.exception_verbosity='high' to receive this slightly extended output:

Traceback (most recent call last):
  File "/home/dotto/.conda/envs/pymctest3/lib/python3.10/site-packages/aesara/compile/function/types.py", line 964, in __call__
    self.fn()
  File "/home/dotto/.conda/envs/pymctest3/lib/python3.10/site-packages/aesara/graph/op.py", line 522, in rval
    r = p(n, [x[0] for x in i], o)
  File "/home/dotto/.conda/envs/pymctest3/lib/python3.10/site-packages/aeppl/transforms.py", line 48, in perform
    raise NotImplementedError(
NotImplementedError: These `Op`s should be removed from graphs used for computation.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "<stdin>", line 3, in <module>
  File "/home/dotto/.conda/envs/pymctest3/lib/python3.10/site-packages/pymc/sampling.py", line 487, in sample
    model.check_start_vals(ip)
  File "/home/dotto/.conda/envs/pymctest3/lib/python3.10/site-packages/pymc/model.py", line 1680, in check_start_vals
    initial_eval = self.point_logps(point=elem)
  File "/home/dotto/.conda/envs/pymctest3/lib/python3.10/site-packages/pymc/model.py", line 1721, in point_logps
    self.compile_fn(factor_logps_fn)(point),
  File "/home/dotto/.conda/envs/pymctest3/lib/python3.10/site-packages/pymc/model.py", line 1820, in __call__
    return self.f(**state)
  File "/home/dotto/.conda/envs/pymctest3/lib/python3.10/site-packages/aesara/compile/function/types.py", line 977, in __call__
    raise_with_op(
  File "/home/dotto/.conda/envs/pymctest3/lib/python3.10/site-packages/aesara/link/utils.py", line 538, in raise_with_op
    raise exc_value.with_traceback(exc_trace)
  File "/home/dotto/.conda/envs/pymctest3/lib/python3.10/site-packages/aesara/compile/function/types.py", line 964, in __call__
    self.fn()
  File "/home/dotto/.conda/envs/pymctest3/lib/python3.10/site-packages/aesara/graph/op.py", line 522, in rval
    r = p(n, [x[0] for x in i], o)
  File "/home/dotto/.conda/envs/pymctest3/lib/python3.10/site-packages/aeppl/transforms.py", line 48, in perform
    raise NotImplementedError(
NotImplementedError: These `Op`s should be removed from graphs used for computation.
Apply node that caused the error: TransformedVariable(Softmax{axis=0}.0, a_simplex__)
Toposort index: 20
Inputs types: [TensorType(float64, (None,)), TensorType(float64, (None,))]
Inputs shapes: [(3,), (2,)]
Inputs strides: [(8,), (8,)]
Inputs values: [array([0.33333333, 0.33333333, 0.33333333]), array([0., 0.])]
Inputs type_num: [12, 12]
Outputs clients: [[Elemwise{eq,no_inplace}(a_simplex___simplex, TensorConstant{(1,) of 0}), Elemwise{gt,no_inplace}(a_simplex___simplex, TensorConstant{(1,) of 1}), Elemwise{lt,no_inplace}(a_simplex___simplex, TensorConstant{(1,) of 0})]]

Backtrace when the node is created (use Aesara flag traceback__limit=N to make it longer):
  File "/home/dotto/.conda/envs/pymctest3/lib/python3.10/site-packages/aeppl/transforms.py", line 203, in apply
    return self.default_transform_opt.optimize(fgraph)
  File "/home/dotto/.conda/envs/pymctest3/lib/python3.10/site-packages/aesara/graph/opt.py", line 103, in optimize
    ret = self.apply(fgraph, *args, **kwargs)
  File "/home/dotto/.conda/envs/pymctest3/lib/python3.10/site-packages/aesara/graph/opt.py", line 1960, in apply
    nb += self.process_node(fgraph, node)
  File "/home/dotto/.conda/envs/pymctest3/lib/python3.10/site-packages/aesara/graph/opt.py", line 1850, in process_node
    replacements = lopt.transform(fgraph, node)
  File "/home/dotto/.conda/envs/pymctest3/lib/python3.10/site-packages/aesara/graph/opt.py", line 1055, in transform
    return self.fn(fgraph, node)
  File "/home/dotto/.conda/envs/pymctest3/lib/python3.10/site-packages/aeppl/transforms.py", line 148, in transform_values
    new_value_var = transformed_variable(
  File "/home/dotto/.conda/envs/pymctest3/lib/python3.10/site-packages/aesara/graph/op.py", line 294, in __call__
    node = self.make_node(*inputs, **kwargs)
  File "/home/dotto/.conda/envs/pymctest3/lib/python3.10/site-packages/aeppl/transforms.py", line 45, in make_node
    return Apply(self, [tran_value, value], [tran_value.type()])

Debug print of the apply node: 
TransformedVariable [id A] <TensorType(float64, (None,))> 'a_simplex___simplex'   

Storage map footprint:
 - Softmax{axis=0}.0, Shape: (3,), ElemSize: 8 Byte(s), TotalSize: 24 Byte(s)
 - TensorConstant{(3,) of 0.0}, Shape: (3,), ElemSize: 8 Byte(s), TotalSize: 24 Byte(s)
 - a_simplex__, Input, Shape: (2,), ElemSize: 8 Byte(s), TotalSize: 16 Byte(s)
 - TensorConstant{(3,) of 0.0}, Shape: (3,), ElemSize: 4 Byte(s), TotalSize: 12 Byte(s)
 - Sum{acc_dtype=float64}.0, Shape: (), ElemSize: 8 Byte(s), TotalSize: 8.0 Byte(s)
 - InplaceDimShuffle{x}.0, Shape: (1,), ElemSize: 8 Byte(s), TotalSize: 8 Byte(s)
 - TensorConstant{(1,) of 0.0}, Shape: (1,), ElemSize: 8 Byte(s), TotalSize: 8 Byte(s)
 - TensorConstant{1}, Shape: (), ElemSize: 8 Byte(s), TotalSize: 8.0 Byte(s)
 - TensorConstant{(1,) of -1.0}, Shape: (1,), ElemSize: 8 Byte(s), TotalSize: 8 Byte(s)
 - argmax, Shape: (), ElemSize: 8 Byte(s), TotalSize: 8.0 Byte(s)
 - TensorConstant{0.6931471805599453}, Shape: (), ElemSize: 8 Byte(s), TotalSize: 8.0 Byte(s)
 - TensorConstant{-inf}, Shape: (), ElemSize: 4 Byte(s), TotalSize: 4.0 Byte(s)
 - TensorConstant{0}, Shape: (), ElemSize: 1 Byte(s), TotalSize: 1.0 Byte(s)
 - TensorConstant{(1,) of 0}, Shape: (1,), ElemSize: 1 Byte(s), TotalSize: 1 Byte(s)
 - TensorConstant{(1,) of 1}, Shape: (1,), ElemSize: 1 Byte(s), TotalSize: 1 Byte(s)
 TotalSize: 131.0 Byte(s) 0.000 GB
 TotalSize inputs: 91.0 Byte(s) 0.000 GB

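I was able to solve the issue by removing the following option from my ~/.aesararc. With `optimizer = fast_compile`, the graph rewrite that removes the `TransformedVariable` Ops is apparently not applied. Those Ops therefore end up in the compiled function and raise the "These `Op`s should be removed from graphs used for computation" error shown above: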

[global]
optimizer = fast_compile
3 Likes