Custom multivariate density via DensityDist

Thanks for your reply, @ricardoV94. I did try specifying `ndim_supp=1`, following the dimensionality guide, but then I get an error I can’t seem to figure out: an `IndexError: list index out of range`, with the following traceback:

Error
---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
Input In [170], in <cell line: 4>()
      2     return x[0] + x[1]  # not an actual density
      4 with pm.Model() as model:
----> 5     custom_dist = pm.DensityDist("custom_dist", logp=f, shape=2, ndim_supp=1)

File ~/.micromamba/envs/bfr-py39/lib/python3.9/site-packages/pymc/distributions/distribution.py:809, in DensityDist.__new__(cls, name, logp, logcdf, random, moment, ndim_supp, ndims_params, dtype, *dist_params, **kwargs)
    806     return moment(rv, size, *dist_params)
    808 cls.rv_op = rv_op
--> 809 return super().__new__(cls, name, *dist_params, **kwargs)

File ~/.micromamba/envs/bfr-py39/lib/python3.9/site-packages/pymc/distributions/distribution.py:263, in Distribution.__new__(cls, name, rng, dims, initval, observed, total_size, transform, *args, **kwargs)
    259     raise TypeError(f"Name needs to be a string but got: {name}")
    261 # Create the RV and process dims and observed to determine
    262 # a shape by which the created RV may need to be resized.
--> 263 rv_out, dims, observed, resize_shape = _make_rv_and_resize_shape(
    264     cls=cls, dims=dims, model=model, observed=observed, args=args, **kwargs
    265 )
    267 if resize_shape:
    268     # A batch size was specified through `dims`, or implied by `observed`.
    269     rv_out = change_rv_size(rv=rv_out, new_size=resize_shape, expand=True)

File ~/.micromamba/envs/bfr-py39/lib/python3.9/site-packages/pymc/distributions/distribution.py:165, in _make_rv_and_resize_shape(cls, dims, model, observed, args, **kwargs)
    162 """Creates the RV and processes dims or observed to determine a resize shape."""
    163 # Create the RV without dims information, because that's not something tracked at the Aesara level.
    164 # If necessary we'll later replicate to a different size implied by already known dims.
--> 165 rv_out = cls.dist(*args, **kwargs)
    166 ndim_actual = rv_out.ndim
    167 resize_shape = None

File ~/.micromamba/envs/bfr-py39/lib/python3.9/site-packages/pymc/distributions/distribution.py:813, in DensityDist.dist(cls, *args, **kwargs)
    811 @classmethod
    812 def dist(cls, *args, **kwargs):
--> 813     output = super().dist(args, **kwargs)
    814     if cls.rv_op.dtype == "floatX":
    815         dtype = aesara.config.floatX

File ~/.micromamba/envs/bfr-py39/lib/python3.9/site-packages/pymc/distributions/distribution.py:351, in Distribution.dist(cls, dist_params, shape, **kwargs)
    346 create_size, ndim_expected, ndim_batch, ndim_supp = find_size(
    347     shape=shape, size=size, ndim_supp=cls.rv_op.ndim_supp
    348 )
    349 # Create the RV with a `size` right away.
    350 # This is not necessarily the final result.
--> 351 rv_out = cls.rv_op(*dist_params, size=create_size, **kwargs)
    353 # Replicate dimensions may be prepended via a shape with Ellipsis as the last element:
    354 if shape is not None and Ellipsis in shape:

File ~/.micromamba/envs/bfr-py39/lib/python3.9/site-packages/aesara/tensor/random/op.py:279, in RandomVariable.__call__(self, size, name, rng, dtype, *args, **kwargs)
    278 def __call__(self, *args, size=None, name=None, rng=None, dtype=None, **kwargs):
--> 279     res = super().__call__(rng, size, dtype, *args, **kwargs)
    281     if name is not None:
    282         res.name = name

File ~/.micromamba/envs/bfr-py39/lib/python3.9/site-packages/aesara/graph/op.py:296, in Op.__call__(self, *inputs, **kwargs)
    254 r"""Construct an `Apply` node using :meth:`Op.make_node` and return its outputs.
    255 
    256 This method is just a wrapper around :meth:`Op.make_node`.
   (...)
    293 
    294 """
    295 return_list = kwargs.pop("return_list", False)
--> 296 node = self.make_node(*inputs, **kwargs)
    298 if config.compute_test_value != "off":
    299     compute_test_value(node)

File ~/.micromamba/envs/bfr-py39/lib/python3.9/site-packages/aesara/tensor/random/op.py:324, in RandomVariable.make_node(self, rng, size, dtype, *dist_params)
    319 elif not isinstance(rng.type, RandomType):
    320     raise TypeError(
    321         "The type of rng should be an instance of either RandomGeneratorType or RandomStateType"
    322     )
--> 324 shape = self._infer_shape(size, dist_params)
    325 _, bcast = infer_broadcastable(shape)
    326 dtype = self.dtype or dtype

File ~/.micromamba/envs/bfr-py39/lib/python3.9/site-packages/aesara/tensor/random/op.py:252, in RandomVariable._infer_shape(self, size, dist_params, param_shapes)
    250     shape_supp = ()
    251 else:
--> 252     shape_supp = self._supp_shape_from_params(
    253         dist_params,
    254         param_shapes=param_shapes,
    255     )
    257 shape = tuple(shape_ind) + tuple(shape_supp)
    258 if not shape:

File ~/.micromamba/envs/bfr-py39/lib/python3.9/site-packages/aesara/tensor/random/op.py:162, in RandomVariable._supp_shape_from_params(self, dist_params, **kwargs)
    154 def _supp_shape_from_params(self, dist_params, **kwargs):
    155     """Determine the support shape of a `RandomVariable`'s output given its parameters.
    156 
    157     This does *not* consider the extra dimensions added by the `size` parameter
   (...)
    160     Defaults to `param_supp_shape_fn`.
    161     """
--> 162     return default_supp_shape_from_params(self.ndim_supp, dist_params, **kwargs)

File ~/.micromamba/envs/bfr-py39/lib/python3.9/site-packages/aesara/tensor/random/op.py:73, in default_supp_shape_from_params(ndim_supp, dist_params, rep_param_idx, param_shapes)
     71     raise ValueError("ndim_supp must be greater than 0")
     72 if param_shapes is not None:
---> 73     ref_param = param_shapes[rep_param_idx]
     74     return (ref_param[-ndim_supp],)
     75 else:

IndexError: list index out of range