Matrix Multiplication With Multiple Dimensions in PYMC Model

Ah that makes a lot of sense okay sorry I was totally overthinking it!

I’m really sorry to ask for your help again, I thought I would have been able to keep going, but I tried taking this information and making it into a hierarchical model, but I am not getting any success. Is there something I’m doing wrong with the shapes here as well? I can’t make my shapes align.

# create a variety id, product id, and x data for each input (we have 200 datapoints)
# we can think of this like store products take pokemon cards
# variety would be trading cards, product would be pokemon card pack, and 
# the x data could be a few inputs like popularity and sales
# y would be the demand for the pokemon cards (which we want to predict)
variety = np.random.randint(0, high=10, size=200)
product = np.random.randint(0, high=30, size=200)
x = np.random.normal(loc=0.0, scale=1.0, size=(200,2))
print(x.shape, variety.shape, product.shape)

# create the target data 
y = np.random.normal(loc=0.0, scale=1.0, size=200)
print(y.shape)

heir_model_variety = pm.Model(coords={'variety':np.arange(10),'product':np.arange(30),'xdims':np.arange(2)})
with heir_model_variety:
    # global parameters
    mu_m = pm.Normal('mu_m', mu=1, sigma=10, dims="xdims")
    sigma_m = pm.HalfNormal('sigma_m', sigma=10, dims="xdims")
    
    mu_b = pm.Normal('mu_b', mu=50_000, sigma=100_000)
    sigma_b = pm.HalfNormal('sigma_b', sigma=100_000)
    
    std = pm.HalfNormal('std', sigma=500_000)
    
    # variety parameters
    mu_m_variety = pm.Normal('mu_m_variety', mu=mu_m, sigma=sigma_m, dims=("variety","xdims"))
    sigma_m_variety = pm.HalfNormal('sigma_m_variety', sigma=10, dims=("variety","xdims"))
    
    mu_b_variety = pm.Normal('mu_b_variety', mu=mu_b, sigma=sigma_b, dims="variety")
    sigma_b_variety = pm.HalfNormal('sigma_b_variety', sigma=100_000, dims="variety")
    
    # product parameters
    m = pm.Normal('m', mu=mu_m_variety, sigma=sigma_m_variety, dims=("product","variety","xdims"))
    b = pm.Normal('b', mu=mu_b_variety, sigma=sigma_b_variety, dims=("product","variety"))
    
    xdata = pm.Data('xdata', x, mutable=True)
    variety_data = pm.Data('variety_data', variety, mutable=True)
    product_data = pm.Data('product_data', product, mutable=True)
    
    print(xdata.shape.eval())
    print(m.T.shape.eval())
    print(xdata.dot(m.T).shape.eval())
    print(b.T.shape.eval())
    mean = xdata.dot(m.T) + b.T
    print(mean.shape.eval())
    obs = pm.Normal('obs', mu=mean[np.arange(200),product_data,variety_data], sigma=std, observed=y)
    
    heir_trace_variety = pm.sample(tune=N_TUNE, return_inferencedata=True, chains=N_CHAINS, target_accept=TARGET_ACCEPT, cores=N_CORES)
(200, 2) (200,) (200,)
(200,)
[200   2]
[ 2 10 30]
[200   2  30]
[10 30]
---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
File ~/miniforge3/envs/pymc/lib/python3.10/site-packages/aesara/compile/function/types.py:975, in Function.__call__(self, *args, **kwargs)
    973 try:
    974     outputs = (
--> 975         self.vm()
    976         if output_subset is None
    977         else self.vm(output_subset=output_subset)
    978     )
    979 except Exception:

File ~/miniforge3/envs/pymc/lib/python3.10/site-packages/aesara/graph/op.py:541, in Op.make_py_thunk.<locals>.rval(p, i, o, n, params)
    533 @is_thunk_type
    534 def rval(
    535     p=p,
   (...)
    539     params=params_val,
    540 ):
--> 541     r = p(n, [x[0] for x in i], o, params)
    542     for o in node.outputs:

File ~/miniforge3/envs/pymc/lib/python3.10/site-packages/aesara/raise_op.py:96, in CheckAndRaise.perform(self, node, inputs, outputs, params)
     95 if not np.all(conds):
---> 96     raise self.exc_type(self.msg)

AssertionError: Could not broadcast dimensions

During handling of the above exception, another exception occurred:

AssertionError                            Traceback (most recent call last)
Input In [9], in <cell line: 16>()
     44 print(b.T.shape.eval())
     45 mean = xdata.dot(m.T) + b.T
---> 46 print(mean.shape.eval())
     47 obs = pm.Normal('obs', mu=mean[np.arange(200),product_data,variety_data], sigma=std, observed=y)
     49 heir_trace_variety = pm.sample(tune=N_TUNE, return_inferencedata=True, chains=N_CHAINS, target_accept=TARGET_ACCEPT, cores=N_CORES)

File ~/miniforge3/envs/pymc/lib/python3.10/site-packages/aesara/graph/basic.py:602, in Variable.eval(self, inputs_to_values)
    599     self._fn_cache[inputs] = function(inputs, self)
    600 args = [inputs_to_values[param] for param in inputs]
--> 602 rval = self._fn_cache[inputs](*args)
    604 return rval

File ~/miniforge3/envs/pymc/lib/python3.10/site-packages/aesara/compile/function/types.py:988, in Function.__call__(self, *args, **kwargs)
    986     if hasattr(self.vm, "thunks"):
    987         thunk = self.vm.thunks[self.vm.position_of_error]
--> 988     raise_with_op(
    989         self.maker.fgraph,
    990         node=self.vm.nodes[self.vm.position_of_error],
    991         thunk=thunk,
    992         storage_map=getattr(self.vm, "storage_map", None),
    993     )
    994 else:
    995     # old-style linkers raise their own exceptions
    996     raise

File ~/miniforge3/envs/pymc/lib/python3.10/site-packages/aesara/link/utils.py:534, in raise_with_op(fgraph, node, thunk, exc_info, storage_map)
    529     warnings.warn(
    530         f"{exc_type} error does not allow us to add an extra error message"
    531     )
    532     # Some exception need extra parameter in inputs. So forget the
    533     # extra long error message in that case.
--> 534 raise exc_value.with_traceback(exc_trace)

File ~/miniforge3/envs/pymc/lib/python3.10/site-packages/aesara/compile/function/types.py:975, in Function.__call__(self, *args, **kwargs)
    972 t0_fn = time.time()
    973 try:
    974     outputs = (
--> 975         self.vm()
    976         if output_subset is None
    977         else self.vm(output_subset=output_subset)
    978     )
    979 except Exception:
    980     restore_defaults()

File ~/miniforge3/envs/pymc/lib/python3.10/site-packages/aesara/graph/op.py:541, in Op.make_py_thunk.<locals>.rval(p, i, o, n, params)
    533 @is_thunk_type
    534 def rval(
    535     p=p,
   (...)
    539     params=params_val,
    540 ):
--> 541     r = p(n, [x[0] for x in i], o, params)
    542     for o in node.outputs:
    543         compute_map[o][0] = True

File ~/miniforge3/envs/pymc/lib/python3.10/site-packages/aesara/raise_op.py:96, in CheckAndRaise.perform(self, node, inputs, outputs, params)
     94 out[0] = val
     95 if not np.all(conds):
---> 96     raise self.exc_type(self.msg)

AssertionError: Could not broadcast dimensions
Apply node that caused the error: Assert{msg=Could not broadcast dimensions}(Abs.0, TensorFromScalar.0)
Toposort index: 32
Inputs types: [ScalarType(int64), TensorType(bool, ())]
Inputs shapes: [(), ()]
Inputs strides: [(), ()]
Inputs values: [10, array(False)]
Outputs clients: [[TensorFromScalar(Assert{msg=Could not broadcast dimensions}.0)]]

HINT: Re-running with most Aesara optimizations disabled could provide a back-trace showing when this node was created. This can be done by setting the Aesara flag 'optimizer=fast_compile'. If that does not work, Aesara optimizations can be disabled with 'optimizer=None'.
HINT: Use the Aesara flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.
1 Like