`pm.sample_prior_predictive()` fails because incorrect size information is passed to the `RandomVariable`'s `rng_fn` method

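Here is the output from trying to run MCMC sampling with the script above after changing `ndim_supp` from 0 to 1. The error is raised while the `DDM` random variable is being constructed (the `ddm = DDM(...)` line), before `pm.sample()` is ever reached. PyTensor infers the support shape from a reference parameter (here `v`) and checks that this parameter has at least `ndim_supp` dimensions. Because `v` is a scalar (0 dimensions), the check fails once `ndim_supp` is 1.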

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[3], line 7
      4 z = pm.Uniform("z", lower=0.01, upper=0.99)
      5 t = pm.Uniform("t", lower=0.0, upper=0.6, initval=0.1)
----> 7 ddm = DDM("ddm", v=v, a=a, z=z, t=t, observed=dataset.values)
      8 # prior_predictives = pm.sample_prior_predictive(500)
     10 ddm_pymc_trace = pm.sample()

File ~/HSSM/.venv/lib/python3.9/site-packages/pymc/distributions/distribution.py:308, in Distribution.__new__(cls, name, rng, dims, initval, observed, total_size, transform, *args, **kwargs)
    305     elif observed is not None:
    306         kwargs["shape"] = tuple(observed.shape)
--> 308 rv_out = cls.dist(*args, **kwargs)
    310 rv_out = model.register_rv(
    311     rv_out,
    312     name,
   (...)
    317     initval=initval,
    318 )
    320 # add in pretty-printing support

File ~/HSSM/src/hssm/distribution_utils/dist.py:276, in make_distribution.<locals>.SSMDistribution.dist(cls, **kwargs)
    272 dist_params = [
    273     pt.as_tensor_variable(pm.floatX(kwargs[param])) for param in cls.params
    274 ]
    275 other_kwargs = {k: v for k, v in kwargs.items() if k not in cls.params}
--> 276 return super().dist(dist_params, **other_kwargs)

File ~/HSSM/.venv/lib/python3.9/site-packages/pymc/distributions/distribution.py:387, in Distribution.dist(cls, dist_params, shape, **kwargs)
    385     ndim_supp = cls.rv_op(*dist_params, **kwargs).owner.op.ndim_supp
    386 create_size = find_size(shape=shape, size=size, ndim_supp=ndim_supp)
--> 387 rv_out = cls.rv_op(*dist_params, size=create_size, **kwargs)
    389 rv_out.logp = _make_nice_attr_error("rv.logp(x)", "pm.logp(rv, x)")
    390 rv_out.logcdf = _make_nice_attr_error("rv.logcdf(x)", "pm.logcdf(rv, x)")

File ~/HSSM/.venv/lib/python3.9/site-packages/pytensor/tensor/random/op.py:289, in RandomVariable.__call__(self, size, name, rng, dtype, *args, **kwargs)
    288 def __call__(self, *args, size=None, name=None, rng=None, dtype=None, **kwargs):
--> 289     res = super().__call__(rng, size, dtype, *args, **kwargs)
    291     if name is not None:
    292         res.name = name

File ~/HSSM/.venv/lib/python3.9/site-packages/pytensor/graph/op.py:295, in Op.__call__(self, *inputs, **kwargs)
    253 r"""Construct an `Apply` node using :meth:`Op.make_node` and return its outputs.
    254 
    255 This method is just a wrapper around :meth:`Op.make_node`.
   (...)
    292 
    293 """
    294 return_list = kwargs.pop("return_list", False)
--> 295 node = self.make_node(*inputs, **kwargs)
    297 if config.compute_test_value != "off":
    298     compute_test_value(node)

File ~/HSSM/.venv/lib/python3.9/site-packages/pytensor/tensor/random/op.py:334, in RandomVariable.make_node(self, rng, size, dtype, *dist_params)
    329 elif not isinstance(rng.type, RandomType):
    330     raise TypeError(
    331         "The type of rng should be an instance of either RandomGeneratorType or RandomStateType"
    332     )
--> 334 shape = self._infer_shape(size, dist_params)
    335 _, static_shape = infer_static_shape(shape)
    336 dtype = self.dtype or dtype

File ~/HSSM/.venv/lib/python3.9/site-packages/pytensor/tensor/random/op.py:210, in RandomVariable._infer_shape(self, size, dist_params, param_shapes)
    208         return size
    209     else:
--> 210         supp_shape = self._supp_shape_from_params(
    211             dist_params, param_shapes=param_shapes
    212         )
    213         return tuple(size) + tuple(supp_shape)
    215 # Broadcast the parameters

File ~/HSSM/.venv/lib/python3.9/site-packages/pytensor/tensor/random/op.py:160, in RandomVariable._supp_shape_from_params(self, dist_params, **kwargs)
    152 def _supp_shape_from_params(self, dist_params, **kwargs):
    153     """Determine the support shape of a `RandomVariable`'s output given its parameters.
    154 
    155     This does *not* consider the extra dimensions added by the `size` parameter
   (...)
    158     Defaults to `param_supp_shape_fn`.
    159     """
--> 160     return default_supp_shape_from_params(self.ndim_supp, dist_params, **kwargs)

File ~/HSSM/.venv/lib/python3.9/site-packages/pytensor/tensor/random/op.py:78, in default_supp_shape_from_params(ndim_supp, dist_params, rep_param_idx, param_shapes)
     76 ref_param = dist_params[rep_param_idx]
     77 if ref_param.ndim < ndim_supp:
---> 78     raise ValueError(
     79         "Reference parameter does not match the "
     80         f"expected dimensions; {ref_param} has less than {ndim_supp} dim(s)."
     81     )
     82 return ref_param.shape[-ndim_supp:]

ValueError: Reference parameter does not match the expected dimensions; v has less than 1 dim(s).