It seems the Dirichlet distribution is not working in v4.0.0b3 yet, although I cannot find anything about this in the release notes. Since mixture distributions have been implemented and there appear to be working examples in the new documentation, I assumed this was already supported. Is this a known work in progress, or a problem on my end? Here is a minimal example:
import numpy as np
import pymc as pm

# Minimal reproduction: a model containing a single Dirichlet variable
# parameterized by np.ones(3). Both statements must sit inside the model
# context so that pm.sample() picks up this model.
with pm.Model() as model:
    a = pm.Dirichlet('a', np.ones(3))
    # Sampling with default settings raises the NotImplementedError
    # shown in the traceback that follows.
    pm.sample()
---------------------------------------------------------------------------
NotImplementedError Traceback (most recent call last)
File ~/.conda/envs/pymc_v1/lib/python3.9/site-packages/aesara/compile/function/types.py:964, in Function.__call__(self, *args, **kwargs)
962 try:
963 outputs = (
--> 964 self.fn()
965 if output_subset is None
966 else self.fn(output_subset=output_subset)
967 )
968 except Exception:
File ~/.conda/envs/pymc_v1/lib/python3.9/site-packages/aesara/graph/op.py:522, in Op.make_py_thunk.<locals>.rval(p, i, o, n, params)
518 @is_thunk_type
519 def rval(
520 p=p, i=node_input_storage, o=node_output_storage, n=node, params=None
521 ):
--> 522 r = p(n, [x[0] for x in i], o)
523 for o in node.outputs:
File ~/.conda/envs/pymc_v1/lib/python3.9/site-packages/aeppl/transforms.py:48, in TransformedVariable.perform(self, node, inputs, outputs)
47 def perform(self, node, inputs, outputs):
---> 48 raise NotImplementedError(
49 "These `Op`s should be removed from graphs used for computation."
50 )
NotImplementedError: These `Op`s should be removed from graphs used for computation.
During handling of the above exception, another exception occurred:
NotImplementedError Traceback (most recent call last)
Input In [9], in <cell line: 4>()
4 with pm.Model() as model:
5 a = pm.Dirichlet('a', np.ones(3))
----> 6 pm.sample()
File ~/.conda/envs/pymc_v1/lib/python3.9/site-packages/pymc/sampling.py:487, in sample(draws, step, init, n_init, initvals, trace, chain_idx, chains, cores, tune, progressbar, model, random_seed, discard_tuned_samples, compute_convergence_checks, callback, jitter_max_retries, return_inferencedata, idata_kwargs, mp_ctx, **kwargs)
485 # One final check that shapes and logps at the starting points are okay.
486 for ip in initial_points:
--> 487 model.check_start_vals(ip)
488 _check_start_shape(model, ip)
490 sample_args = {
491 "draws": draws,
492 "step": step,
(...)
503 "discard_tuned_samples": discard_tuned_samples,
504 }
File ~/.conda/envs/pymc_v1/lib/python3.9/site-packages/pymc/model.py:1680, in Model.check_start_vals(self, start)
1674 valid_keys = ", ".join(self.named_vars.keys())
1675 raise KeyError(
1676 "Some start parameters do not appear in the model!\n"
1677 f"Valid keys are: {valid_keys}, but {extra_keys} was supplied"
1678 )
-> 1680 initial_eval = self.point_logps(point=elem)
1682 if not all(np.isfinite(v) for v in initial_eval.values()):
1683 raise SamplingError(
1684 "Initial evaluation of model at starting point failed!\n"
1685 f"Starting values:\n{elem}\n\n"
1686 f"Initial evaluation results:\n{initial_eval}"
1687 )
File ~/.conda/envs/pymc_v1/lib/python3.9/site-packages/pymc/model.py:1721, in Model.point_logps(self, point, round_vals)
1715 factors = self.basic_RVs + self.potentials
1716 factor_logps_fn = [at.sum(factor) for factor in self.logpt(factors, sum=False)]
1717 return {
1718 factor.name: np.round(np.asarray(factor_logp), round_vals)
1719 for factor, factor_logp in zip(
1720 factors,
-> 1721 self.compile_fn(factor_logps_fn)(point),
1722 )
1723 }
File ~/.conda/envs/pymc_v1/lib/python3.9/site-packages/pymc/model.py:1820, in PointFunc.__call__(self, state)
1819 def __call__(self, state):
-> 1820 return self.f(**state)
File ~/.conda/envs/pymc_v1/lib/python3.9/site-packages/aesara/compile/function/types.py:977, in Function.__call__(self, *args, **kwargs)
975 if hasattr(self.fn, "thunks"):
976 thunk = self.fn.thunks[self.fn.position_of_error]
--> 977 raise_with_op(
978 self.maker.fgraph,
979 node=self.fn.nodes[self.fn.position_of_error],
980 thunk=thunk,
981 storage_map=getattr(self.fn, "storage_map", None),
982 )
983 else:
984 # old-style linkers raise their own exceptions
985 raise
File ~/.conda/envs/pymc_v1/lib/python3.9/site-packages/aesara/link/utils.py:538, in raise_with_op(fgraph, node, thunk, exc_info, storage_map)
533 warnings.warn(
534 f"{exc_type} error does not allow us to add an extra error message"
535 )
536 # Some exception need extra parameter in inputs. So forget the
537 # extra long error message in that case.
--> 538 raise exc_value.with_traceback(exc_trace)
File ~/.conda/envs/pymc_v1/lib/python3.9/site-packages/aesara/compile/function/types.py:964, in Function.__call__(self, *args, **kwargs)
961 t0_fn = time.time()
962 try:
963 outputs = (
--> 964 self.fn()
965 if output_subset is None
966 else self.fn(output_subset=output_subset)
967 )
968 except Exception:
969 restore_defaults()
File ~/.conda/envs/pymc_v1/lib/python3.9/site-packages/aesara/graph/op.py:522, in Op.make_py_thunk.<locals>.rval(p, i, o, n, params)
518 @is_thunk_type
519 def rval(
520 p=p, i=node_input_storage, o=node_output_storage, n=node, params=None
521 ):
--> 522 r = p(n, [x[0] for x in i], o)
523 for o in node.outputs:
524 compute_map[o][0] = True
File ~/.conda/envs/pymc_v1/lib/python3.9/site-packages/aeppl/transforms.py:48, in TransformedVariable.perform(self, node, inputs, outputs)
47 def perform(self, node, inputs, outputs):
---> 48 raise NotImplementedError(
49 "These `Op`s should be removed from graphs used for computation."
50 )
NotImplementedError: These `Op`s should be removed from graphs used for computation.
Apply node that caused the error: TransformedVariable(Softmax{axis=0}.0, a_simplex__)
Toposort index: 20
Inputs types: [TensorType(float64, (None,)), TensorType(float64, (None,))]
Inputs shapes: [(3,), (2,)]
Inputs strides: [(8,), (8,)]
Inputs values: [array([0.33333333, 0.33333333, 0.33333333]), array([0., 0.])]
Outputs clients: [[Elemwise{eq,no_inplace}(a_simplex___simplex, TensorConstant{(1,) of 0}), Elemwise{gt,no_inplace}(a_simplex___simplex, TensorConstant{(1,) of 1}), Elemwise{lt,no_inplace}(a_simplex___simplex, TensorConstant{(1,) of 0})]]
Backtrace when the node is created (use Aesara flag traceback__limit=N to make it longer):
File "/home/dotto/.conda/envs/pymc_v1/lib/python3.9/site-packages/aeppl/transforms.py", line 203, in apply
return self.default_transform_opt.optimize(fgraph)
File "/home/dotto/.conda/envs/pymc_v1/lib/python3.9/site-packages/aesara/graph/opt.py", line 103, in optimize
ret = self.apply(fgraph, *args, **kwargs)
File "/home/dotto/.conda/envs/pymc_v1/lib/python3.9/site-packages/aesara/graph/opt.py", line 1960, in apply
nb += self.process_node(fgraph, node)
File "/home/dotto/.conda/envs/pymc_v1/lib/python3.9/site-packages/aesara/graph/opt.py", line 1850, in process_node
replacements = lopt.transform(fgraph, node)
File "/home/dotto/.conda/envs/pymc_v1/lib/python3.9/site-packages/aesara/graph/opt.py", line 1055, in transform
return self.fn(fgraph, node)
File "/home/dotto/.conda/envs/pymc_v1/lib/python3.9/site-packages/aeppl/transforms.py", line 148, in transform_values
new_value_var = transformed_variable(
File "/home/dotto/.conda/envs/pymc_v1/lib/python3.9/site-packages/aesara/graph/op.py", line 294, in __call__
node = self.make_node(*inputs, **kwargs)
File "/home/dotto/.conda/envs/pymc_v1/lib/python3.9/site-packages/aeppl/transforms.py", line 45, in make_node
return Apply(self, [tran_value, value], [tran_value.type()])
HINT: Use the Aesara flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.
I installed everything with Conda and my versions are:
- PyMC 4.0.0b4
- Aesara 2.5.1
- aePPL 0.0.27
- Ubuntu 18.04.5 LTS