PyMC/PyTensor Error with Multivariate Normal Distribution: Incompatible Elemwise Input Shape

Hello, I’m working with PyMC and encountering an issue when defining a multivariate normal distribution. My objective is to fit a 2D Gaussian to data (x1 and x2) of stars, incorporating observational errors (error_x1 and error_x2).

However, when I define my model and attempt to run the sampler, I encounter a ValueError related to incompatible shapes. The error message is: ValueError: Incompatible Elemwise input shapes [(247, 2), (2, 2)].

I’m using PyMC 5.10.3 and PyTensor 2.18.4 on macOS Sonoma (MacBook Air M2).

Here’s the relevant portion of my code:

Here `x1`, `x2`, `error_x1` and `error_x2` are arrays of length 246.

    # Sample medians, used below as the centres of the priors on the means.
    mean_x1 = np.median(x1)
    mean_x2 = np.median(x2)
    with pm.Model() as pm_model:
        # Priors for 2D Gaussian parameters
        mu_x1 = pm.Normal('mu_x1', mu=mean_x1, sigma=5)  
        mu_x2 = pm.Normal('mu_x2', mu=mean_x2, sigma=5) 
        # NOTE(review): shape=len(x1) / shape=len(x2) declares one sigma per
        # data point rather than a single population-wide scale — confirm
        # this is intended.
        sigma_x1 = pm.HalfNormal('sigma_x1', sigma=5,shape=len(x1))
        sigma_x2 = pm.HalfNormal('sigma_x2', sigma=5,shape=len(x2))
        # Correlation coefficient between the two dimensions, in [-1, 1].
        corr = pm.Uniform('corr', lower=-1, upper=1)

        # Adjusted covariance matrix to include observational errors
        # NOTE(review): every entry here is a vector (sigma_* and error_* are
        # per-point), so the stacked result presumably has shape
        # (2, 2, n_points), whereas the MvNormal logp works on the last two
        # axes as the matrix axes, i.e. it would need (n_points, 2, 2).
        # Verify with `cov.eval().shape`.
        cov = pm.math.stack([
            [sigma_x1**2 + error_x1**2, corr * sigma_x1 * sigma_x2],
            [corr * sigma_x1 * sigma_x2, sigma_x2**2 + error_x2**2]
        ])
        # Multivariate normal likelihood
        # NOTE(review): mu_x1 and mu_x2 are scalar RVs, so axis=1 in the stack
        # looks unintended; a length-2 mean vector would broadcast against the
        # (n_points, 2) observations — TODO confirm with `mu.eval().shape`.
        obs = pm.MvNormal('obs', mu=pm.math.stack([mu_x1, mu_x2],axis=1), cov=cov, observed=np.stack([x1, x2], axis=1))

This is the full traceback:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
File ~/miniconda3/lib/python3.11/site-packages/pytensor/tensor/elemwise.py:437, in Elemwise.get_output_info(self, dim_shuffle, *inputs)
    435 try:
    436     out_shapes = [
--> 437         [
    438             broadcast_static_dim_lengths(shape)
    439             for shape in zip(*[inp.type.shape for inp in inputs])
    440         ]
    441     ] * shadow.nout
    442 except ValueError:

File ~/miniconda3/lib/python3.11/site-packages/pytensor/tensor/elemwise.py:438, in <listcomp>(.0)
    435 try:
    436     out_shapes = [
    437         [
--> 438             broadcast_static_dim_lengths(shape)
    439             for shape in zip(*[inp.type.shape for inp in inputs])
    440         ]
    441     ] * shadow.nout
    442 except ValueError:

File ~/miniconda3/lib/python3.11/site-packages/pytensor/tensor/utils.py:163, in broadcast_static_dim_lengths(dim_lengths)
    162 if len(dim_lengths_set) > 1:
--> 163     raise ValueError
    164 return tuple(dim_lengths_set)[0]

ValueError: 

During handling of the above exception, another exception occurred:

File ~/miniconda3/lib/python3.11/site-packages/pymc/sampling/jax.py:271, in _get_batched_jittered_initial_points(model, chains, initvals, random_seed, jitter, jitter_max_retries)
    254 def _get_batched_jittered_initial_points(
    255     model: Model,
    256     chains: int,
   (...)
    260     jitter_max_retries: int = 10,
    261 ) -> Union[np.ndarray, List[np.ndarray]]:
    262     """Get jittered initial point in format expected by NumPyro MCMC kernel
    263 
    264     Returns
   (...)
    268         Each item has shape `(chains, *var.shape)`
    269     """
--> 271     initial_points = _init_jitter(
    272         model,
    273         initvals,
    274         seeds=_get_seeds_per_chain(random_seed, chains),
    275         jitter=jitter,
    276         jitter_max_retries=jitter_max_retries,
    277     )
    278     initial_points_values = [list(initial_point.values()) for initial_point in initial_points]
    279     if chains == 1:

File ~/miniconda3/lib/python3.11/site-packages/pymc/sampling/mcmc.py:1257, in _init_jitter(model, initvals, seeds, jitter, jitter_max_retries)
   1255 if i < jitter_max_retries:
   1256     try:
-> 1257         model.check_start_vals(point)
   1258     except SamplingError:
   1259         # Retry with a new seed
   1260         seed = rng.randint(2**30, dtype=np.int64)

File ~/miniconda3/lib/python3.11/site-packages/pymc/model/core.py:1657, in Model.check_start_vals(self, start)
   1651     valid_keys = ", ".join(value_names_set)
   1652     raise KeyError(
   1653         "Some start parameters do not appear in the model!\n"
   1654         f"Valid keys are: {valid_keys}, but {extra_keys} was supplied"
   1655     )
-> 1657 initial_eval = self.point_logps(point=elem)
   1659 if not all(np.isfinite(v) for v in initial_eval.values()):
   1660     raise SamplingError(
   1661         "Initial evaluation of model at starting point failed!\n"
   1662         f"Starting values:\n{elem}\n\n"
   1663         f"Logp initial evaluation results:\n{initial_eval}\n"
   1664         "You can call `model.debug()` for more details."
   1665     )

File ~/miniconda3/lib/python3.11/site-packages/pymc/model/core.py:1687, in Model.point_logps(self, point, round_vals)
   1684     point = self.initial_point()
   1686 factors = self.basic_RVs + self.potentials
-> 1687 factor_logps_fn = [pt.sum(factor) for factor in self.logp(factors, sum=False)]
   1688 return {
   1689     factor.name: np.round(np.asarray(factor_logp), round_vals)
   1690     for factor, factor_logp in zip(
   (...)
   1693     )
   1694 }

File ~/miniconda3/lib/python3.11/site-packages/pymc/model/core.py:727, in Model.logp(self, vars, jacobian, sum)
    725 rv_logps: List[TensorVariable] = []
    726 if rvs:
--> 727     rv_logps = transformed_conditional_logp(
    728         rvs=rvs,
    729         rvs_to_values=self.rvs_to_values,
    730         rvs_to_transforms=self.rvs_to_transforms,
    731         jacobian=jacobian,
    732     )
    733     assert isinstance(rv_logps, list)
    735 # Replace random variables by their value variables in potential terms

File ~/miniconda3/lib/python3.11/site-packages/pymc/logprob/basic.py:611, in transformed_conditional_logp(rvs, rvs_to_values, rvs_to_transforms, jacobian, **kwargs)
    608     transform_rewrite = TransformValuesRewrite(values_to_transforms)  # type: ignore
    610 kwargs.setdefault("warn_rvs", False)
--> 611 temp_logp_terms = conditional_logp(
    612     rvs_to_values,
    613     extra_rewrites=transform_rewrite,
    614     use_jacobian=jacobian,
    615     **kwargs,
    616 )
    618 # The function returns the logp for every single value term we provided to it.
    619 # This includes the extra values we plugged in above, so we filter those we
    620 # actually wanted in the same order they were given in.
    621 logp_terms = {}

File ~/miniconda3/lib/python3.11/site-packages/pymc/logprob/basic.py:541, in conditional_logp(rv_values, warn_rvs, ir_rewriter, extra_rewrites, **kwargs)
    538 q_values = remapped_vars[: len(q_values)]
    539 q_rv_inputs = remapped_vars[len(q_values) :]
--> 541 q_logprob_vars = _logprob(
    542     node.op,
    543     q_values,
    544     *q_rv_inputs,
    545     **kwargs,
    546 )
    548 if not isinstance(q_logprob_vars, (list, tuple)):
    549     q_logprob_vars = [q_logprob_vars]

File ~/miniconda3/lib/python3.11/functools.py:909, in singledispatch.<locals>.wrapper(*args, **kw)
    905 if not args:
    906     raise TypeError(f'{funcname} requires at least '
    907                     '1 positional argument')
--> 909 return dispatch(args[0].__class__)(*args, **kw)

File ~/miniconda3/lib/python3.11/site-packages/pymc/distributions/distribution.py:193, in DistributionMeta.__new__.<locals>.logp(op, values, *dist_params, **kwargs)
    191 dist_params = dist_params[3:]
    192 (value,) = values
--> 193 return class_logp(value, *dist_params)

File ~/miniconda3/lib/python3.11/site-packages/pymc/distributions/multivariate.py:270, in MvNormal.logp(value, mu, cov)
    256 def logp(value, mu, cov):
    257     """
    258     Calculate log-probability of Multivariate Normal distribution
    259     at specified value.
   (...)
    268     TensorVariable
    269     """
--> 270     quaddist, logdet, ok = quaddist_chol(value, mu, cov)
    271     k = floatX(value.shape[-1])
    272     norm = -0.5 * k * pm.floatX(np.log(2 * np.pi))

File ~/miniconda3/lib/python3.11/site-packages/pymc/distributions/multivariate.py:155, in quaddist_chol(value, mu, cov)
    152 else:
    153     onedim = False
--> 155 delta = value - mu
    156 chol_cov = nan_lower_cholesky(cov)
    158 diag = pt.diagonal(chol_cov, axis1=-2, axis2=-1)

File ~/miniconda3/lib/python3.11/site-packages/pytensor/tensor/variable.py:125, in _tensor_py_operators.__sub__(self, other)
    121 def __sub__(self, other):
    122     # See explanation in __add__ for the error caught
    123     # and the return value in that case
    124     try:
--> 125         return pt.math.sub(self, other)
    126     except (NotImplementedError, TypeError):
    127         return NotImplemented

File ~/miniconda3/lib/python3.11/site-packages/pytensor/graph/op.py:295, in Op.__call__(self, *inputs, **kwargs)
    253 r"""Construct an `Apply` node using :meth:`Op.make_node` and return its outputs.
    254 
    255 This method is just a wrapper around :meth:`Op.make_node`.
   (...)
    292 
    293 """
    294 return_list = kwargs.pop("return_list", False)
--> 295 node = self.make_node(*inputs, **kwargs)
    297 if config.compute_test_value != "off":
    298     compute_test_value(node)

File ~/miniconda3/lib/python3.11/site-packages/pytensor/tensor/elemwise.py:481, in Elemwise.make_node(self, *inputs)
    475 """
    476 If the inputs have different number of dimensions, their shape
    477 is left-completed to the greatest number of dimensions with 1s
    478 using DimShuffle.
    479 """
    480 inputs = [as_tensor_variable(i) for i in inputs]
--> 481 out_dtypes, out_shapes, inputs = self.get_output_info(DimShuffle, *inputs)
    482 outputs = [
    483     TensorType(dtype=dtype, shape=shape)()
    484     for dtype, shape in zip(out_dtypes, out_shapes)
    485 ]
    486 return Apply(self, inputs, outputs)

File ~/miniconda3/lib/python3.11/site-packages/pytensor/tensor/elemwise.py:443, in Elemwise.get_output_info(self, dim_shuffle, *inputs)
    436     out_shapes = [
    437         [
    438             broadcast_static_dim_lengths(shape)
    439             for shape in zip(*[inp.type.shape for inp in inputs])
    440         ]
    441     ] * shadow.nout
    442 except ValueError:
--> 443     raise ValueError(
    444         f"Incompatible Elemwise input shapes {[inp.type.shape for inp in inputs]}"
    445     )
    447 # inplace_pattern maps output idx -> input idx
    448 inplace_pattern = self.inplace_pattern

ValueError: Incompatible Elemwise input shapes [(247, 2), (2, 2)]

If I add `shape=len(x1)` and `shape=len(x2)` to `mu_x1` and `mu_x2`, I get a similar error:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
File ~/miniconda3/lib/python3.11/site-packages/pytensor/tensor/elemwise.py:437, in Elemwise.get_output_info(self, dim_shuffle, *inputs)
    435 try:
    436     out_shapes = [
--> 437         [
    438             broadcast_static_dim_lengths(shape)
    439             for shape in zip(*[inp.type.shape for inp in inputs])
    440         ]
    441     ] * shadow.nout
    442 except ValueError:

File ~/miniconda3/lib/python3.11/site-packages/pytensor/tensor/elemwise.py:438, in <listcomp>(.0)
    435 try:
    436     out_shapes = [
    437         [
--> 438             broadcast_static_dim_lengths(shape)
    439             for shape in zip(*[inp.type.shape for inp in inputs])
    440         ]
    441     ] * shadow.nout
    442 except ValueError:

File ~/miniconda3/lib/python3.11/site-packages/pytensor/tensor/utils.py:163, in broadcast_static_dim_lengths(dim_lengths)
    162 if len(dim_lengths_set) > 1:
--> 163     raise ValueError
    164 return tuple(dim_lengths_set)[0]

ValueError: 

During handling of the above exception, another exception occurred:

File ~/miniconda3/lib/python3.11/site-packages/pymc/distributions/distribution.py:369, in Distribution.__new__(cls, name, rng, dims, initval, observed, total_size, transform, *args, **kwargs)
    366     elif observed is not None:
    367         kwargs["shape"] = tuple(observed.shape)
--> 369 rv_out = cls.dist(*args, **kwargs)
    371 rv_out = model.register_rv(
    372     rv_out,
    373     name,
   (...)
    378     initval=initval,
    379 )
    381 # add in pretty-printing support

File ~/miniconda3/lib/python3.11/site-packages/pymc/distributions/multivariate.py:245, in MvNormal.dist(cls, mu, cov, tau, chol, lower, **kwargs)
    243 cov = quaddist_matrix(cov, chol, tau, lower)
    244 # PyTensor is stricter about the shape of mu, than PyMC used to be
--> 245 mu, _ = pt.broadcast_arrays(mu, cov[..., -1])
    246 return super().dist([mu, cov], **kwargs)

File ~/miniconda3/lib/python3.11/site-packages/pytensor/tensor/extra_ops.py:1656, in broadcast_arrays(*args)
   1653 for i, a in enumerate(args):
   1654     # We use indexing and not identity in case there are duplicated variables
   1655     others = [a for j, a in enumerate(args) if j != i]
-> 1656     brodacasted_vars.append(broadcast_with_others(a, others))
   1658 return brodacasted_vars

File ~/miniconda3/lib/python3.11/site-packages/pytensor/tensor/extra_ops.py:1649, in broadcast_arrays.<locals>.broadcast_with_others(a, others)
   1647 def broadcast_with_others(a, others):
   1648     for other in others:
-> 1649         a = second(other, a)
   1650     return a

File ~/miniconda3/lib/python3.11/site-packages/pytensor/graph/op.py:295, in Op.__call__(self, *inputs, **kwargs)
    253 r"""Construct an `Apply` node using :meth:`Op.make_node` and return its outputs.
    254 
    255 This method is just a wrapper around :meth:`Op.make_node`.
   (...)
    292 
    293 """
    294 return_list = kwargs.pop("return_list", False)
--> 295 node = self.make_node(*inputs, **kwargs)
    297 if config.compute_test_value != "off":
    298     compute_test_value(node)

File ~/miniconda3/lib/python3.11/site-packages/pytensor/tensor/elemwise.py:481, in Elemwise.make_node(self, *inputs)
    475 """
    476 If the inputs have different number of dimensions, their shape
    477 is left-completed to the greatest number of dimensions with 1s
    478 using DimShuffle.
    479 """
    480 inputs = [as_tensor_variable(i) for i in inputs]
--> 481 out_dtypes, out_shapes, inputs = self.get_output_info(DimShuffle, *inputs)
    482 outputs = [
    483     TensorType(dtype=dtype, shape=shape)()
    484     for dtype, shape in zip(out_dtypes, out_shapes)
    485 ]
    486 return Apply(self, inputs, outputs)

File ~/miniconda3/lib/python3.11/site-packages/pytensor/tensor/elemwise.py:443, in Elemwise.get_output_info(self, dim_shuffle, *inputs)
    436     out_shapes = [
    437         [
    438             broadcast_static_dim_lengths(shape)
    439             for shape in zip(*[inp.type.shape for inp in inputs])
    440         ]
    441     ] * shadow.nout
    442 except ValueError:
--> 443     raise ValueError(
    444         f"Incompatible Elemwise input shapes {[inp.type.shape for inp in inputs]}"
    445     )
    447 # inplace_pattern maps output idx -> input idx
    448 inplace_pattern = self.inplace_pattern

ValueError: Incompatible Elemwise input shapes [(2, 2), (247, 2)]

Try calling `.eval()` on the `cov` and `mu` variables to check whether they have the shapes you expect.