Implementing Chinese Restaurant Process - pytensor error on CustomDist

Hello PyMC team and users :slight_smile:
I am trying to implement the Chinese Restaurant Process (CRP) in order to build a latent-cause model and reproduce in PyMC the results of the paper “Context, learning, and extinction” (PubMed).
(The paper actually used a particle filter for inference instead of MCMC.)
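
For reference, the CRP seating rule can be sketched in plain NumPy (not PyMC/PyTensor). This is an illustrative sketch, not the code from the paper, and `sample_crp` is a hypothetical name. Customer n (0-based) opens a new table with probability alpha / (alpha + n); otherwise they join an existing table with probability proportional to the number of customers already seated there. Tables are labelled 1, 2, ... in order of creation, the same convention the dist function below uses.

```python
import numpy as np


def sample_crp(alpha: float, n_customers: int, rng: np.random.Generator) -> np.ndarray:
    """Draw one CRP seating; tables are labelled 1, 2, ... in order of creation."""
    labels = np.zeros(n_customers, dtype=int)
    counts = []  # counts[k] = number of customers already at table k + 1
    for n in range(n_customers):
        # Customer n (0-based) opens a new table with probability alpha / (alpha + n);
        # for n == 0 this is 1, so the first customer always opens table 1.
        if rng.uniform() < alpha / (alpha + n):
            counts.append(1)
            labels[n] = len(counts)
        else:
            # Otherwise join an existing table with probability n_k / n
            # (n customers are already seated).
            probs = np.array(counts) / n
            k = rng.choice(len(counts), p=probs)
            counts[k] += 1
            labels[n] = k + 1
    return labels


rng = np.random.default_rng(0)
print("low alpha: ", sample_crp(1.0, 10, rng))
print("high alpha:", sample_crp(100.0, 10, rng))
```

The two-stage draw gives each existing table k the overall probability (n / (alpha + n)) · (n_k / n) = n_k / (alpha + n), which is the standard CRP rule.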

I have managed to create the following “dist” function that seems to work well in pytensor as a standalone function.

```python
import numpy as np
import pymc as pm
from pytensor.tensor import TensorVariable
import pytensor
import pytensor.tensor as T
from pymc.pytensorf import collect_default_updates


def dist(
    alpha: TensorVariable,
    size: TensorVariable,
) -> TensorVariable:

    # Find out when new tables are created
    first_roll = pm.Uniform.dist(size=size)
    customers = T.cumsum(T.ones(size)) - 1
    newtable_chances = 1.0 * alpha / (alpha + customers)
    newtables = first_roll < newtable_chances
    newtables_int = newtables.astype('int')
    newtables_numbers = newtables_int * T.cumsum(newtables_int)

    assignments = newtables_numbers


    # Function to assign old tables whenever a new table is not created
    def assign_table(isnew, index, prior, nonseq):
        counts = T.unique(prior[:(index)], return_counts=True)[1]
        probs = counts/counts.sum()
        choice = pm.Categorical.dist(p=probs) + 1
        zeros_subtensor = prior[index]
        value1 = choice * (1 - isnew.astype('int'))  # Write the old table or..
        value2 = nonseq[index] * isnew.astype('int')  # ..rewrite the new table
        output = T.set_subtensor(zeros_subtensor, value1 + value2)
        return output, collect_default_updates(output)

    indexes = pytensor.tensor.arange(newtables_numbers.shape[0])
    result, updates = pytensor.scan(fn=assign_table,
                                      outputs_info=T.ones_like(assignments, dtype='int'),
                                      sequences=[newtables[1:], indexes[1:]],
                                      non_sequences=assignments)

    return result[-1].astype('int')


print("10 customers with low alpha:", dist(1, 10).eval())
print("10 customers with very high alpha:", dist(100, 10).eval())
```

Running the above gives:

WARNING (pytensor.tensor.blas): Using NumPy C-API based implementation for BLAS functions.
10 customers with low alpha: [1 2 1 1 3 2 3 3 1 1]
10 customers with very high alpha: [ 1  2  3  4  5  6  7  8  9 10]
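
A quick sanity check on these draws: since customer n opens a new table with probability alpha / (alpha + n), linearity of expectation gives the expected number of tables as the sum of those probabilities. A small NumPy sketch (the helper name `expected_tables` is hypothetical):

```python
import numpy as np


def expected_tables(alpha: float, n_customers: int) -> float:
    # Customer n (0-based) opens a new table with probability alpha / (alpha + n);
    # by linearity of expectation the expected number of tables is the sum.
    n = np.arange(n_customers)
    return float(np.sum(alpha / (alpha + n)))


print(round(expected_tables(1.0, 10), 2))    # 2.93 (the 10th harmonic number)
print(round(expected_tables(100.0, 10), 2))  # 9.58
```

With 10 customers that is about 2.93 tables for alpha = 1 and about 9.58 for alpha = 100, so the 3 and 10 tables in the output above are plausible single draws.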

However, once I try to use the dist function inside CustomDist, it gives the following error:

```python
with pm.Model() as m:
    alpha = pm.Categorical('alpha', p=[0.25, 0.25, 0.25, 0.25]) + 1
    crp = pm.CustomDist(
        "crp",
        alpha,
        dist=dist,
        observed=[1,1,1,2,1,1,1],
    )

    prior = pm.sample_prior_predictive()
    posterior = pm.sample()
```
---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
Cell In[2], line 3
      1 with pm.Model() as m:
      2     alpha = pm.Categorical('alpha', p=[0.25, 0.25, 0.25, 0.25]) + 1
----> 3     crp = pm.CustomDist(
      4         "crp",
      5         alpha,
      6         dist=dist,
      7         observed=[1,1,1,2,1,1,1],
      8     )
     10     prior = pm.sample_prior_predictive()
     11     posterior = pm.sample()

File ~/miniconda3/envs/pymc_env/lib/python3.11/site-packages/pymc/distributions/distribution.py:958, in CustomDist.__new__(cls, name, dist, random, logp, logcdf, moment, ndim_supp, ndims_params, dtype, *dist_params, **kwargs)
    956 if dist is not None:
    957     kwargs.setdefault("class_name", f"CustomDist_{name}")
--> 958     return _CustomSymbolicDist(
    959         name,
    960         *dist_params,
    961         dist=dist,
    962         logp=logp,
    963         logcdf=logcdf,
    964         moment=moment,
    965         ndim_supp=ndim_supp,
    966         **kwargs,
    967     )
    968 else:
    969     kwargs.setdefault("class_name", f"CustomDist_{name}")

File ~/miniconda3/envs/pymc_env/lib/python3.11/site-packages/pymc/distributions/distribution.py:308, in Distribution.__new__(cls, name, rng, dims, initval, observed, total_size, transform, *args, **kwargs)
    305     elif observed is not None:
    306         kwargs["shape"] = tuple(observed.shape)
--> 308 rv_out = cls.dist(*args, **kwargs)
    310 rv_out = model.register_rv(
    311     rv_out,
    312     name,
   (...)
    317     initval=initval,
    318 )
    320 # add in pretty-printing support

File ~/miniconda3/envs/pymc_env/lib/python3.11/site-packages/pymc/distributions/distribution.py:622, in _CustomSymbolicDist.dist(cls, dist, logp, logcdf, moment, ndim_supp, dtype, class_name, *dist_params, **kwargs)
    614 if moment is None:
    615     moment = functools.partial(
    616         default_moment,
    617         rv_name=class_name,
    618         has_fallback=True,
    619         ndim_supp=ndim_supp,
    620     )
--> 622 return super().dist(
    623     dist_params,
    624     class_name=class_name,
    625     logp=logp,
    626     logcdf=logcdf,
    627     dist=dist,
    628     moment=moment,
    629     ndim_supp=ndim_supp,
    630     **kwargs,
    631 )

File ~/miniconda3/envs/pymc_env/lib/python3.11/site-packages/pymc/distributions/distribution.py:385, in Distribution.dist(cls, dist_params, shape, **kwargs)
    383 ndim_supp = getattr(cls.rv_op, "ndim_supp", None)
    384 if ndim_supp is None:
--> 385     ndim_supp = cls.rv_op(*dist_params, **kwargs).owner.op.ndim_supp
    386 create_size = find_size(shape=shape, size=size, ndim_supp=ndim_supp)
    387 rv_out = cls.rv_op(*dist_params, size=create_size, **kwargs)

File ~/miniconda3/envs/pymc_env/lib/python3.11/site-packages/pymc/distributions/distribution.py:651, in _CustomSymbolicDist.rv_op(cls, dist, logp, logcdf, moment, size, ndim_supp, class_name, *dist_params)
    647 dummy_dist_params = [dist_param.type() for dist_param in dist_params]
    648 with new_or_existing_block_model_access(
    649     error_msg_on_access="Model variables cannot be created in the dist function. Use the `.dist` API"
    650 ):
--> 651     dummy_rv = dist(*dummy_dist_params, dummy_size_param)
    652 dummy_params = [dummy_size_param] + dummy_dist_params
    653 dummy_updates_dict = collect_default_updates(inputs=dummy_params, outputs=(dummy_rv,))

Cell In[1], line 37, in dist(alpha, size)
     34     return output, collect_default_updates(output)
     36 indexes = pytensor.tensor.arange(newtables_numbers.shape[0])
---> 37 result, updates = pytensor.scan(fn=assign_table,
     38                                   outputs_info=T.ones_like(assignments, dtype='int'),
     39                                   sequences=[newtables[1:], indexes[1:]],
     40                                   non_sequences=assignments)
     42 return result[-1].astype('int')

File ~/miniconda3/envs/pymc_env/lib/python3.11/site-packages/pytensor/scan/basic.py:744, in scan(fn, sequences, outputs_info, non_sequences, n_steps, truncate_gradient, go_backwards, mode, name, profile, allow_gc, strict, return_list)
    738     arg.name = init_out["initial"].name + "[t-1]"
    740 # We need now to allocate space for storing the output and copy
    741 # the initial state over. We do this using the expand function
    742 # defined in scan utils
    743 sit_sot_scan_inputs.append(
--> 744     expand_empty(
    745         unbroadcast(shape_padleft(actual_arg), 0),
    746         actual_n_steps,
    747     )
    748 )
    750 sit_sot_inner_slices.append(actual_arg)
    751 if i in return_steps:

File ~/miniconda3/envs/pymc_env/lib/python3.11/site-packages/pytensor/scan/utils.py:235, in expand_empty(tensor_var, size)
    233 shapes = [tensor_var.shape[x] for x in range(tensor_var.ndim)]
    234 new_shape = [size + shapes[0]] + shapes[1:]
--> 235 empty = AllocEmpty(tensor_var.dtype)(*new_shape)
    237 ret = set_subtensor(empty[: shapes[0]], tensor_var)
    238 ret.tag.nan_guard_mode_check = False

File ~/miniconda3/envs/pymc_env/lib/python3.11/site-packages/pytensor/graph/op.py:295, in Op.__call__(self, *inputs, **kwargs)
    253 r"""Construct an `Apply` node using :meth:`Op.make_node` and return its outputs.
    254 
    255 This method is just a wrapper around :meth:`Op.make_node`.
   (...)
    292 
    293 """
    294 return_list = kwargs.pop("return_list", False)
--> 295 node = self.make_node(*inputs, **kwargs)
    297 if config.compute_test_value != "off":
    298     compute_test_value(node)

File ~/miniconda3/envs/pymc_env/lib/python3.11/site-packages/pytensor/tensor/basic.py:3835, in AllocEmpty.make_node(self, *_shape)
   3834 def make_node(self, *_shape):
-> 3835     _shape, static_shape = infer_static_shape(_shape)
   3836     otype = TensorType(dtype=self.dtype, shape=static_shape)
   3837     output = otype()

File ~/miniconda3/envs/pymc_env/lib/python3.11/site-packages/pytensor/tensor/basic.py:1386, in infer_static_shape(shape)
   1382     raise TypeError(f"Shapes must be scalar integers; got {s_as_str}")
   1384 sh = [check_type(as_tensor_variable(s, ndim=0)) for s in shape]
-> 1386 shape_fg = FunctionGraph(
   1387     outputs=sh,
   1388     features=[ShapeFeature()],
   1389     clone=True,
   1390 )
   1391 folded_shape = rewrite_graph(shape_fg, custom_rewrite=topo_constant_folding).outputs
   1392 static_shape = tuple(
   1393     s.data.item() if isinstance(s, Constant) else None for s in folded_shape
   1394 )

File ~/miniconda3/envs/pymc_env/lib/python3.11/site-packages/pytensor/graph/fg.py:157, in FunctionGraph.__init__(self, inputs, outputs, features, clone, update_mapping, **clone_kwds)
    154     self.add_input(in_var, check=False)
    156 for output in outputs:
--> 157     self.add_output(output, reason="init")
    159 self.profile = None
    160 self.update_mapping = update_mapping

File ~/miniconda3/envs/pymc_env/lib/python3.11/site-packages/pytensor/graph/fg.py:167, in FunctionGraph.add_output(self, var, reason, import_missing)
    165 """Add a new variable as an output to this `FunctionGraph`."""
    166 self.outputs.append(var)
--> 167 self.import_var(var, reason=reason, import_missing=import_missing)
    168 self.clients[var].append(("output", len(self.outputs) - 1))

File ~/miniconda3/envs/pymc_env/lib/python3.11/site-packages/pytensor/graph/fg.py:308, in FunctionGraph.import_var(self, var, reason, import_missing)
    306 # Imports the owners of the variables
    307 if var.owner and var.owner not in self.apply_nodes:
--> 308     self.import_node(var.owner, reason=reason, import_missing=import_missing)
    309 elif (
    310     var.owner is None
    311     and not isinstance(var, AtomicVariable)
    312     and var not in self.inputs
    313 ):
    314     from pytensor.graph.null_type import NullType

File ~/miniconda3/envs/pymc_env/lib/python3.11/site-packages/pytensor/graph/fg.py:389, in FunctionGraph.import_node(self, apply_node, check, reason, import_missing)
    387         self.variables.add(input)
    388     self.add_client(input, (node, i))
--> 389 self.execute_callbacks("on_import", node, reason)

File ~/miniconda3/envs/pymc_env/lib/python3.11/site-packages/pytensor/graph/fg.py:727, in FunctionGraph.execute_callbacks(self, name, *args, **kwargs)
    725         continue
    726     tf0 = time.perf_counter()
--> 727     fn(self, *args, **kwargs)
    728     self.execute_callbacks_times[feature] += time.perf_counter() - tf0
    729 self.execute_callbacks_time += time.perf_counter() - t0

File ~/miniconda3/envs/pymc_env/lib/python3.11/site-packages/pytensor/tensor/rewriting/shape.py:569, in ShapeFeature.on_import(self, fgraph, node, reason)
    562 for i, d in enumerate(sh):
    563     # Note: we ignore any shape element that is not typed (i.e.,
    564     # does not have a 'dtype' attribute). This means there may
    565     # still remain int elements that are int32 on 32-bit platforms,
    566     # but this works with `local_useless_subtensor`, so for now we
    567     # keep it this way. See #266 for a better long-term fix.
    568     if getattr(d, "dtype", "int64") != "int64":
--> 569         assert d.dtype in discrete_dtypes, (node, d.dtype)
    570         assert str(d.dtype) != "uint64", node
    571         new_shape += sh[len(new_shape) : i + 1]

AssertionError: (CumOp{None, add}(Alloc.0), 'float64')

Can anyone point me in the right direction?

Digging a bit further, it seems that during initialization CustomDist calls the dist function with the following dummy variables:
<Scalar(int64, shape=())> <Vector(int64, shape=(0,))>

<Scalar(int64, shape=())> is fine as the alpha parameter, but passing <Vector(int64, shape=(0,))> as size makes building the dist graph fail at the scan call.

dist(T.scalar(dtype='int64'), 0)
Out[7]: Subtensor{i}.0

dist(1, T.vector(shape=(0,), dtype='int64'))

AssertionError: (CumOp{None, add}(Alloc.0), 'float64')
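
Assuming PyTensor follows NumPy's semantics for `ones` and for an axis-less `cumsum` (the `CumOp{None, add}(Alloc.0)` node in the error appears to be `T.cumsum(T.ones(size))`), a NumPy analogue shows what the first lines of dist build when size is empty: `ones(size)` is a scalar, and `cumsum` without an axis flattens it into a length-1 float64 vector instead of one entry per customer.

```python
import numpy as np

# An empty size, mirroring the Vector(int64, shape=(0,)) dummy that CustomDist passes.
size = ()

ones = np.ones(size)
print(ones.shape)  # () : a single scalar, not one entry per customer

# cumsum without an axis flattens its input first, so the scalar becomes a length-1 vector.
customers = np.cumsum(ones) - 1
print(customers.shape, customers.dtype)  # (1,) float64
```

The float64 dtype of that cumsum is the same one reported in the AssertionError.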

Haven’t figured out the solution yet :slight_smile:

CustomDist is called twice: once without size and once with it. The first call is only there to figure out the support dimensionality of the distribution; the second time it’s called with the actual size. So you can try to handle the case where size.type.shape==(0,) just so that the first call succeeds and the second call happens.

There are, however, two bigger issues, in that PyMC will not be able to infer the logp of your dist. The first is that your CustomDist involves two RVs: the Uniform first_roll and the Categorical choice inside the Scan. I’m not sure this is a problem in general, but the graph between first_roll and assignments is definitely too complex for PyMC to derive the logprob from. Either way, you could get around this by moving the first Uniform outside of the CustomDist and passing it as a parameter.
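
A NumPy sketch of that split, assuming first_roll is drawn outside and passed in as a parameter: the new-table indicators then become a deterministic function of first_roll and alpha, the same comparison the first lines of dist perform. The function name `new_table_flags` is hypothetical.

```python
import numpy as np


def new_table_flags(first_roll: np.ndarray, alpha: float) -> np.ndarray:
    """True where customer n opens a new table, given pre-drawn uniforms."""
    customers = np.arange(first_roll.shape[0])  # 0, 1, ..., n - 1 (as in dist)
    return first_roll < alpha / (alpha + customers)


# The uniforms are drawn once, outside, and passed in like any other parameter.
rng = np.random.default_rng(0)
first_roll = rng.uniform(size=10)
print(new_table_flags(first_roll, alpha=1.0))
```

Once first_roll is an input, the only randomness left inside the function is the choice among existing tables.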

The more important issue is that PyMC does not know how to derive the logprob of a set_subtensor operation. If you can rewrite your function to avoid it, PyMC should be able to derive the whole logprob. Also, PyMC doesn’t know what to do with the last astype(int), but it looks like you should be able to avoid it.
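
One possible restructuring, sketched in plain NumPy (an assumption, not verified against PyMC's logprob inference): each loop step carries the table counts and emits only the current customer's label, so the labels are stacked by the loop instead of being written into a carried vector with set_subtensor. The sketch also swaps the two-stage draw for a single Categorical over the occupied tables plus one "new table" slot, which gives the same probabilities (n_k / (alpha + n) for table k, alpha / (alpha + n) for a new table). Counts are updated by adding a one-hot vector rather than by indexed assignment. The names `crp_step` and `sample_crp_scan_style` are hypothetical.

```python
import numpy as np


def crp_step(counts: np.ndarray, alpha: float, rng: np.random.Generator):
    """One loop step: returns (updated counts, this customer's table label)."""
    slots = np.arange(counts.shape[0])
    n_tables = np.count_nonzero(counts)  # occupied tables are slots 0 .. n_tables - 1
    # Weight n_k for each occupied table and alpha for the first empty slot (a new table);
    # the total is n + alpha, where n customers are already seated.
    weights = counts + alpha * (slots == n_tables)
    k = rng.choice(counts.shape[0], p=weights / weights.sum())
    # Update the carried counts by adding a one-hot vector instead of an indexed write.
    return counts + (slots == k), k + 1


def sample_crp_scan_style(alpha: float, n_customers: int, rng: np.random.Generator) -> np.ndarray:
    counts = np.zeros(n_customers, dtype=int)  # one slot per possible table
    labels = []
    for _ in range(n_customers):
        counts, label = crp_step(counts, alpha, rng)
        labels.append(label)  # per-step outputs are collected, as scan stacks them
    return np.array(labels)


rng = np.random.default_rng(0)
print(sample_crp_scan_style(1.0, 10, rng))
```

Whether the PyTensor translation of this step (a Categorical whose p depends on the carried counts) is something PyMC can derive a logprob for would still need to be checked.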

I suggest you start with a simpler Scan RV (even if it’s not what you care about) to pinpoint where things fail. PyMC’s logprob inference is still new, and there are a lot of operations it can’t handle. Unfortunately, we don’t have useful messages besides “logprob could not be derived”, which doesn’t help much when so many operations could have caused the failure.

Thanks for the answer, @ricardoV94!

> CustomDist is called twice [...] So you can try to handle the case where size.type.shape==(0,) [...]

Yes, that’s what I have been trying.

> [...] you could get around this by moving the first Uniform outside of the CustomDist and passing it as a parameter.

Thanks, I will do that.

> [...] PyMC does not know how to derive the logprob of a set_subtensor operation. [...] PyMC doesn’t know what to do with the last astype(int) [...]

Yes, I can easily avoid the astype('int'): I added it because the error message seemed to hint that it expected an “int64” tensor type somewhere.
Is there a list somewhere of which tensor operations are supported and which are not? That way I would know whether there are other operations I should avoid.

> [...] we don’t have useful messages besides “logprob could not be derived” [...]

One of the things that stalled my debugging is that the error message did not explicitly say that the logprob derivation was the problem; it looked more like a shape or tensor-type problem.
Now that I know that the set_subtensor operation and the logprob derivation are the problem, I’ll try different approaches.
For example, instead of passing dist to CustomDist, I could try providing random and logp functions directly.
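
For the random + logp route, the logp of an observed seating follows directly from the sequential seating probabilities. A plain-Python sketch, using the 1, 2, ... labelling from dist (`crp_logp` is a hypothetical name; a CustomDist logp would need the same computation expressed with PyTensor ops, which is not shown here):

```python
import math


def crp_logp(labels, alpha):
    """Log-probability of a CRP seating; labels are 1, 2, ... in order of table creation."""
    counts = []  # counts[k] = customers already seated at table k + 1
    logp = 0.0
    for n, label in enumerate(labels):
        if label == len(counts) + 1:
            # Opens the next new table: probability alpha / (alpha + n).
            logp += math.log(alpha / (alpha + n))
            counts.append(1)
        elif 1 <= label <= len(counts):
            # Joins existing table `label`: probability n_k / (alpha + n).
            logp += math.log(counts[label - 1] / (alpha + n))
            counts[label - 1] += 1
        else:
            # A label that skips ahead is impossible under this labelling.
            return -math.inf
    return logp


print(crp_logp([1, 1, 1, 2, 1, 1, 1], alpha=1.0))  # equals log(1/42), about -3.7377
```

For the observed data in the model above and alpha = 1, the sequential probabilities multiply to 1 · 1/2 · 2/3 · 1/4 · 3/5 · 4/6 · 5/7 = 1/42.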

Thanks again, I’ll post again once I solve the issue (or find more problems :sweat_smile:).