Hello PyMC team and users

I am trying to implement the Chinese Restaurant Process to then make a latent causes model and reproduce in PyMC the results of this paper: Context, learning, and extinction - PubMed

(the paper actually used a particle filter instead of MCMC for inference)

I have managed to create the following ādistā function that seems to work well in pytensor as a standalone function.

```
import numpy as np
import pymc as pm
from pytensor.tensor import TensorVariable
import pytensor
import pytensor.tensor as T
from pymc.pytensorf import collect_default_updates
def dist(
alpha: TensorVariable,
size: TensorVariable,
) -> TensorVariable:
# Find out when new tables are created
first_roll = pm.Uniform.dist(size=size)
customers = T.cumsum(T.ones(size)) - 1
newtable_chances = 1.0 * alpha / (alpha + customers)
newtables = first_roll < newtable_chances
newtables_int = newtables.astype('int')
newtables_numbers = newtables_int * T.cumsum(newtables_int)
assignments = newtables_numbers
# Function to assign old tables whenever a new table is not created
def assign_table(isnew, index, prior, nonseq):
counts = T.unique(prior[:(index)], return_counts=True)[1]
probs = counts/counts.sum()
choice = pm.Categorical.dist(p=probs) + 1
zeros_subtensor = prior[index]
value1 = choice * (1 - isnew.astype('int')) # Write the old table or..
value2 = nonseq[index] * isnew.astype('int') # ..rewrite the new table
output = T.set_subtensor(zeros_subtensor, value1 + value2)
return output, collect_default_updates(output)
indexes = pytensor.tensor.arange(newtables_numbers.shape[0])
result, updates = pytensor.scan(fn=assign_table,
outputs_info=T.ones_like(assignments, dtype='int'),
sequences=[newtables[1:], indexes[1:]],
non_sequences=assignments)
return result[-1].astype('int')
print("10 customers with low alpha:", dist(1, 10).eval())
print("10 customers with very high alpha:", dist(100, 10).eval())
```

Running the above gives:

```
WARNING (pytensor.tensor.blas): Using NumPy C-API based implementation for BLAS functions.
10 customers with low alpha: [1 2 1 1 3 2 3 3 1 1]
10 customers with very high alpha: [ 1 2 3 4 5 6 7 8 9 10]
```

However, once I try to use the dist function inside CustomDist, it gives the following error:

```
with pm.Model() as m:
alpha = pm.Categorical('alpha', p=[0.25, 0.25, 0.25, 0.25]) + 1
crp = pm.CustomDist(
"crp",
alpha,
dist=dist,
observed=[1,1,1,2,1,1,1],
)
prior = pm.sample_prior_predictive()
posterior = pm.sample()
```

```
---------------------------------------------------------------------------
AssertionError Traceback (most recent call last)
Cell In[2], line 3
1 with pm.Model() as m:
2 alpha = pm.Categorical('alpha', p=[0.25, 0.25, 0.25, 0.25]) + 1
----> 3 crp = pm.CustomDist(
4 "crp",
5 alpha,
6 dist=dist,
7 observed=[1,1,1,2,1,1,1],
8 )
10 prior = pm.sample_prior_predictive()
11 posterior = pm.sample()
File ~/miniconda3/envs/pymc_env/lib/python3.11/site-packages/pymc/distributions/distribution.py:958, in CustomDist.__new__(cls, name, dist, random, logp, logcdf, moment, ndim_supp, ndims_params, dtype, *dist_params, **kwargs)
956 if dist is not None:
957 kwargs.setdefault("class_name", f"CustomDist_{name}")
--> 958 return _CustomSymbolicDist(
959 name,
960 *dist_params,
961 dist=dist,
962 logp=logp,
963 logcdf=logcdf,
964 moment=moment,
965 ndim_supp=ndim_supp,
966 **kwargs,
967 )
968 else:
969 kwargs.setdefault("class_name", f"CustomDist_{name}")
File ~/miniconda3/envs/pymc_env/lib/python3.11/site-packages/pymc/distributions/distribution.py:308, in Distribution.__new__(cls, name, rng, dims, initval, observed, total_size, transform, *args, **kwargs)
305 elif observed is not None:
306 kwargs["shape"] = tuple(observed.shape)
--> 308 rv_out = cls.dist(*args, **kwargs)
310 rv_out = model.register_rv(
311 rv_out,
312 name,
(...)
317 initval=initval,
318 )
320 # add in pretty-printing support
File ~/miniconda3/envs/pymc_env/lib/python3.11/site-packages/pymc/distributions/distribution.py:622, in _CustomSymbolicDist.dist(cls, dist, logp, logcdf, moment, ndim_supp, dtype, class_name, *dist_params, **kwargs)
614 if moment is None:
615 moment = functools.partial(
616 default_moment,
617 rv_name=class_name,
618 has_fallback=True,
619 ndim_supp=ndim_supp,
620 )
--> 622 return super().dist(
623 dist_params,
624 class_name=class_name,
625 logp=logp,
626 logcdf=logcdf,
627 dist=dist,
628 moment=moment,
629 ndim_supp=ndim_supp,
630 **kwargs,
631 )
File ~/miniconda3/envs/pymc_env/lib/python3.11/site-packages/pymc/distributions/distribution.py:385, in Distribution.dist(cls, dist_params, shape, **kwargs)
383 ndim_supp = getattr(cls.rv_op, "ndim_supp", None)
384 if ndim_supp is None:
--> 385 ndim_supp = cls.rv_op(*dist_params, **kwargs).owner.op.ndim_supp
386 create_size = find_size(shape=shape, size=size, ndim_supp=ndim_supp)
387 rv_out = cls.rv_op(*dist_params, size=create_size, **kwargs)
File ~/miniconda3/envs/pymc_env/lib/python3.11/site-packages/pymc/distributions/distribution.py:651, in _CustomSymbolicDist.rv_op(cls, dist, logp, logcdf, moment, size, ndim_supp, class_name, *dist_params)
647 dummy_dist_params = [dist_param.type() for dist_param in dist_params]
648 with new_or_existing_block_model_access(
649 error_msg_on_access="Model variables cannot be created in the dist function. Use the `.dist` API"
650 ):
--> 651 dummy_rv = dist(*dummy_dist_params, dummy_size_param)
652 dummy_params = [dummy_size_param] + dummy_dist_params
653 dummy_updates_dict = collect_default_updates(inputs=dummy_params, outputs=(dummy_rv,))
Cell In[1], line 37, in dist(alpha, size)
34 return output, collect_default_updates(output)
36 indexes = pytensor.tensor.arange(newtables_numbers.shape[0])
---> 37 result, updates = pytensor.scan(fn=assign_table,
38 outputs_info=T.ones_like(assignments, dtype='int'),
39 sequences=[newtables[1:], indexes[1:]],
40 non_sequences=assignments)
42 return result[-1].astype('int')
File ~/miniconda3/envs/pymc_env/lib/python3.11/site-packages/pytensor/scan/basic.py:744, in scan(fn, sequences, outputs_info, non_sequences, n_steps, truncate_gradient, go_backwards, mode, name, profile, allow_gc, strict, return_list)
738 arg.name = init_out["initial"].name + "[t-1]"
740 # We need now to allocate space for storing the output and copy
741 # the initial state over. We do this using the expand function
742 # defined in scan utils
743 sit_sot_scan_inputs.append(
--> 744 expand_empty(
745 unbroadcast(shape_padleft(actual_arg), 0),
746 actual_n_steps,
747 )
748 )
750 sit_sot_inner_slices.append(actual_arg)
751 if i in return_steps:
File ~/miniconda3/envs/pymc_env/lib/python3.11/site-packages/pytensor/scan/utils.py:235, in expand_empty(tensor_var, size)
233 shapes = [tensor_var.shape[x] for x in range(tensor_var.ndim)]
234 new_shape = [size + shapes[0]] + shapes[1:]
--> 235 empty = AllocEmpty(tensor_var.dtype)(*new_shape)
237 ret = set_subtensor(empty[: shapes[0]], tensor_var)
238 ret.tag.nan_guard_mode_check = False
File ~/miniconda3/envs/pymc_env/lib/python3.11/site-packages/pytensor/graph/op.py:295, in Op.__call__(self, *inputs, **kwargs)
253 r"""Construct an `Apply` node using :meth:`Op.make_node` and return its outputs.
254
255 This method is just a wrapper around :meth:`Op.make_node`.
(...)
292
293 """
294 return_list = kwargs.pop("return_list", False)
--> 295 node = self.make_node(*inputs, **kwargs)
297 if config.compute_test_value != "off":
298 compute_test_value(node)
File ~/miniconda3/envs/pymc_env/lib/python3.11/site-packages/pytensor/tensor/basic.py:3835, in AllocEmpty.make_node(self, *_shape)
3834 def make_node(self, *_shape):
-> 3835 _shape, static_shape = infer_static_shape(_shape)
3836 otype = TensorType(dtype=self.dtype, shape=static_shape)
3837 output = otype()
File ~/miniconda3/envs/pymc_env/lib/python3.11/site-packages/pytensor/tensor/basic.py:1386, in infer_static_shape(shape)
1382 raise TypeError(f"Shapes must be scalar integers; got {s_as_str}")
1384 sh = [check_type(as_tensor_variable(s, ndim=0)) for s in shape]
-> 1386 shape_fg = FunctionGraph(
1387 outputs=sh,
1388 features=[ShapeFeature()],
1389 clone=True,
1390 )
1391 folded_shape = rewrite_graph(shape_fg, custom_rewrite=topo_constant_folding).outputs
1392 static_shape = tuple(
1393 s.data.item() if isinstance(s, Constant) else None for s in folded_shape
1394 )
File ~/miniconda3/envs/pymc_env/lib/python3.11/site-packages/pytensor/graph/fg.py:157, in FunctionGraph.__init__(self, inputs, outputs, features, clone, update_mapping, **clone_kwds)
154 self.add_input(in_var, check=False)
156 for output in outputs:
--> 157 self.add_output(output, reason="init")
159 self.profile = None
160 self.update_mapping = update_mapping
File ~/miniconda3/envs/pymc_env/lib/python3.11/site-packages/pytensor/graph/fg.py:167, in FunctionGraph.add_output(self, var, reason, import_missing)
165 """Add a new variable as an output to this `FunctionGraph`."""
166 self.outputs.append(var)
--> 167 self.import_var(var, reason=reason, import_missing=import_missing)
168 self.clients[var].append(("output", len(self.outputs) - 1))
File ~/miniconda3/envs/pymc_env/lib/python3.11/site-packages/pytensor/graph/fg.py:308, in FunctionGraph.import_var(self, var, reason, import_missing)
306 # Imports the owners of the variables
307 if var.owner and var.owner not in self.apply_nodes:
--> 308 self.import_node(var.owner, reason=reason, import_missing=import_missing)
309 elif (
310 var.owner is None
311 and not isinstance(var, AtomicVariable)
312 and var not in self.inputs
313 ):
314 from pytensor.graph.null_type import NullType
File ~/miniconda3/envs/pymc_env/lib/python3.11/site-packages/pytensor/graph/fg.py:389, in FunctionGraph.import_node(self, apply_node, check, reason, import_missing)
387 self.variables.add(input)
388 self.add_client(input, (node, i))
--> 389 self.execute_callbacks("on_import", node, reason)
File ~/miniconda3/envs/pymc_env/lib/python3.11/site-packages/pytensor/graph/fg.py:727, in FunctionGraph.execute_callbacks(self, name, *args, **kwargs)
725 continue
726 tf0 = time.perf_counter()
--> 727 fn(self, *args, **kwargs)
728 self.execute_callbacks_times[feature] += time.perf_counter() - tf0
729 self.execute_callbacks_time += time.perf_counter() - t0
File ~/miniconda3/envs/pymc_env/lib/python3.11/site-packages/pytensor/tensor/rewriting/shape.py:569, in ShapeFeature.on_import(self, fgraph, node, reason)
562 for i, d in enumerate(sh):
563 # Note: we ignore any shape element that is not typed (i.e.,
564 # does not have a 'dtype' attribute). This means there may
565 # still remain int elements that are int32 on 32-bit platforms,
566 # but this works with `local_useless_subtensor`, so for now we
567 # keep it this way. See #266 for a better long-term fix.
568 if getattr(d, "dtype", "int64") != "int64":
--> 569 assert d.dtype in discrete_dtypes, (node, d.dtype)
570 assert str(d.dtype) != "uint64", node
571 new_shape += sh[len(new_shape) : i + 1]
AssertionError: (CumOp{None, add}(Alloc.0), 'float64')
```

Can anyone point me in the right direction?