NotImplementedError with Dirichlet distribution in PyMC 4.0.0b4

I was not able to reproduce this on macOS, nor with a different user account on the same Ubuntu machine, so I am now quite confused about where the error comes from. With my own account on the Ubuntu machine, I can reproduce it in multiple freshly set-up environments, even after making sure PYTHONPATH is unset, cleaning the PATH environment variable, and running the example from different working directories. What else could be causing this?
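A minimal sketch of the environment checks described above, using only the Python standard library, so that the output can be compared between the two user accounts (for example by diffing it). It prints the interpreter, working directory, PYTHONPATH and module search path, and the file each of pymc, aesara and aeppl would be imported from. The per-user site-packages and module-origin lines go beyond the checks already listed and are an assumption about what else might be worth comparing; the helper names are hypothetical.

```python
import importlib.util
import os
import site
import sys


def module_origin(name):
    """Return the file a top-level module would be imported from, or None if not found.

    find_spec locates the module without importing it.
    """
    spec = importlib.util.find_spec(name)
    return spec.origin if spec is not None else None


def describe_environment(modules=("pymc", "aesara", "aeppl")):
    """Collect interpreter and path settings that can differ per user or per shell."""
    return {
        "executable": sys.executable,                 # which python binary is running
        "version": sys.version.split()[0],
        "cwd": os.getcwd(),                           # current working directory
        "PYTHONPATH": os.environ.get("PYTHONPATH"),   # None when unset
        "user_site_enabled": site.ENABLE_USER_SITE,   # True if the per-user site dir is on sys.path
        "user_site": site.getusersitepackages(),      # per-user site-packages directory
        "sys_path": list(sys.path),                   # full module search path, in order
        "module_origins": {m: module_origin(m) for m in modules},
    }


if __name__ == "__main__":
    for key, value in describe_environment().items():
        if isinstance(value, (list, dict)):
            print(f"{key}:")
            items = value.items() if isinstance(value, dict) else enumerate(value)
            for label, entry in items:
                print(f"    {label}: {entry}")
        else:
            print(f"{key}: {value}")
```

Running it once per account inside the same conda environment shows whether both accounts resolve the packages to the same files.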

After setting aesara.config.exception_verbosity = 'high', I get the following, slightly more detailed output:

Traceback (most recent call last):
  File "/home/dotto/.conda/envs/pymctest3/lib/python3.10/site-packages/aesara/compile/function/types.py", line 964, in __call__
    self.fn()
  File "/home/dotto/.conda/envs/pymctest3/lib/python3.10/site-packages/aesara/graph/op.py", line 522, in rval
    r = p(n, [x[0] for x in i], o)
  File "/home/dotto/.conda/envs/pymctest3/lib/python3.10/site-packages/aeppl/transforms.py", line 48, in perform
    raise NotImplementedError(
NotImplementedError: These `Op`s should be removed from graphs used for computation.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "<stdin>", line 3, in <module>
  File "/home/dotto/.conda/envs/pymctest3/lib/python3.10/site-packages/pymc/sampling.py", line 487, in sample
    model.check_start_vals(ip)
  File "/home/dotto/.conda/envs/pymctest3/lib/python3.10/site-packages/pymc/model.py", line 1680, in check_start_vals
    initial_eval = self.point_logps(point=elem)
  File "/home/dotto/.conda/envs/pymctest3/lib/python3.10/site-packages/pymc/model.py", line 1721, in point_logps
    self.compile_fn(factor_logps_fn)(point),
  File "/home/dotto/.conda/envs/pymctest3/lib/python3.10/site-packages/pymc/model.py", line 1820, in __call__
    return self.f(**state)
  File "/home/dotto/.conda/envs/pymctest3/lib/python3.10/site-packages/aesara/compile/function/types.py", line 977, in __call__
    raise_with_op(
  File "/home/dotto/.conda/envs/pymctest3/lib/python3.10/site-packages/aesara/link/utils.py", line 538, in raise_with_op
    raise exc_value.with_traceback(exc_trace)
  File "/home/dotto/.conda/envs/pymctest3/lib/python3.10/site-packages/aesara/compile/function/types.py", line 964, in __call__
    self.fn()
  File "/home/dotto/.conda/envs/pymctest3/lib/python3.10/site-packages/aesara/graph/op.py", line 522, in rval
    r = p(n, [x[0] for x in i], o)
  File "/home/dotto/.conda/envs/pymctest3/lib/python3.10/site-packages/aeppl/transforms.py", line 48, in perform
    raise NotImplementedError(
NotImplementedError: These `Op`s should be removed from graphs used for computation.
Apply node that caused the error: TransformedVariable(Softmax{axis=0}.0, a_simplex__)
Toposort index: 20
Inputs types: [TensorType(float64, (None,)), TensorType(float64, (None,))]
Inputs shapes: [(3,), (2,)]
Inputs strides: [(8,), (8,)]
Inputs values: [array([0.33333333, 0.33333333, 0.33333333]), array([0., 0.])]
Inputs type_num: [12, 12]
Outputs clients: [[Elemwise{eq,no_inplace}(a_simplex___simplex, TensorConstant{(1,) of 0}), Elemwise{gt,no_inplace}(a_simplex___simplex, TensorConstant{(1,) of 1}), Elemwise{lt,no_inplace}(a_simplex___simplex, TensorConstant{(1,) of 0})]]

Backtrace when the node is created (use Aesara flag traceback__limit=N to make it longer):
  File "/home/dotto/.conda/envs/pymctest3/lib/python3.10/site-packages/aeppl/transforms.py", line 203, in apply
    return self.default_transform_opt.optimize(fgraph)
  File "/home/dotto/.conda/envs/pymctest3/lib/python3.10/site-packages/aesara/graph/opt.py", line 103, in optimize
    ret = self.apply(fgraph, *args, **kwargs)
  File "/home/dotto/.conda/envs/pymctest3/lib/python3.10/site-packages/aesara/graph/opt.py", line 1960, in apply
    nb += self.process_node(fgraph, node)
  File "/home/dotto/.conda/envs/pymctest3/lib/python3.10/site-packages/aesara/graph/opt.py", line 1850, in process_node
    replacements = lopt.transform(fgraph, node)
  File "/home/dotto/.conda/envs/pymctest3/lib/python3.10/site-packages/aesara/graph/opt.py", line 1055, in transform
    return self.fn(fgraph, node)
  File "/home/dotto/.conda/envs/pymctest3/lib/python3.10/site-packages/aeppl/transforms.py", line 148, in transform_values
    new_value_var = transformed_variable(
  File "/home/dotto/.conda/envs/pymctest3/lib/python3.10/site-packages/aesara/graph/op.py", line 294, in __call__
    node = self.make_node(*inputs, **kwargs)
  File "/home/dotto/.conda/envs/pymctest3/lib/python3.10/site-packages/aeppl/transforms.py", line 45, in make_node
    return Apply(self, [tran_value, value], [tran_value.type()])

Debug print of the apply node: 
TransformedVariable [id A] <TensorType(float64, (None,))> 'a_simplex___simplex'   

Storage map footprint:
 - Softmax{axis=0}.0, Shape: (3,), ElemSize: 8 Byte(s), TotalSize: 24 Byte(s)
 - TensorConstant{(3,) of 0.0}, Shape: (3,), ElemSize: 8 Byte(s), TotalSize: 24 Byte(s)
 - a_simplex__, Input, Shape: (2,), ElemSize: 8 Byte(s), TotalSize: 16 Byte(s)
 - TensorConstant{(3,) of 0.0}, Shape: (3,), ElemSize: 4 Byte(s), TotalSize: 12 Byte(s)
 - Sum{acc_dtype=float64}.0, Shape: (), ElemSize: 8 Byte(s), TotalSize: 8.0 Byte(s)
 - InplaceDimShuffle{x}.0, Shape: (1,), ElemSize: 8 Byte(s), TotalSize: 8 Byte(s)
 - TensorConstant{(1,) of 0.0}, Shape: (1,), ElemSize: 8 Byte(s), TotalSize: 8 Byte(s)
 - TensorConstant{1}, Shape: (), ElemSize: 8 Byte(s), TotalSize: 8.0 Byte(s)
 - TensorConstant{(1,) of -1.0}, Shape: (1,), ElemSize: 8 Byte(s), TotalSize: 8 Byte(s)
 - argmax, Shape: (), ElemSize: 8 Byte(s), TotalSize: 8.0 Byte(s)
 - TensorConstant{0.6931471805599453}, Shape: (), ElemSize: 8 Byte(s), TotalSize: 8.0 Byte(s)
 - TensorConstant{-inf}, Shape: (), ElemSize: 4 Byte(s), TotalSize: 4.0 Byte(s)
 - TensorConstant{0}, Shape: (), ElemSize: 1 Byte(s), TotalSize: 1.0 Byte(s)
 - TensorConstant{(1,) of 0}, Shape: (1,), ElemSize: 1 Byte(s), TotalSize: 1 Byte(s)
 - TensorConstant{(1,) of 1}, Shape: (1,), ElemSize: 1 Byte(s), TotalSize: 1 Byte(s)
 TotalSize: 131.0 Byte(s) 0.000 GB
 TotalSize inputs: 91.0 Byte(s) 0.000 GB
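For reading the "Inputs values" line: the failing TransformedVariable node receives the back-transformed Dirichlet value (shape (3,)) and the unconstrained value a_simplex__ (shape (2,)). A minimal sketch, assuming the simplex transform appends the negated sum to the unconstrained vector and then applies a softmax (consistent with the Softmax{axis=0} input shown, but not confirmed by the output itself), shows why the zero start point maps to three equal weights. The function name simplex_backward is hypothetical.

```python
import math


def simplex_backward(unconstrained):
    """Map an unconstrained vector of length K-1 onto the K-simplex.

    Assumed form: append minus the sum of the entries, then apply a
    numerically stable softmax (subtract the max before exponentiating).
    """
    full = list(unconstrained) + [-sum(unconstrained)]
    top = max(full)
    exps = [math.exp(v - top) for v in full]
    total = sum(exps)
    return [e / total for e in exps]


# The unconstrained start value from the error message: a_simplex__ = [0., 0.]
point = simplex_backward([0.0, 0.0])
print(point)  # three equal weights of 1/3, matching "Inputs values" above
```

This suggests the start values themselves are well-formed; the error is raised because the TransformedVariable node was still present when the graph was evaluated, which is what the NotImplementedError message says should not happen.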