`Cannot drop a non-broadcastable dimension` and `rfft`

Running the following code yields ValueError: Cannot drop a non-broadcastable dimension: (False, False), [1].

# This is a bogus model that reproduces the error
n = 4
with pm.Model() as model:
    p = pm.Dirichlet("p", a=pt.ones(n))
    dft = pm.Deterministic("dft", fft.rfft(pt.shape_padaxis(p, 0))) # returns shape (1,n//2+1,2)
    dftreshape = pm.Deterministic("dftreshape", dft[0,:,0]) # returns shape (n//2+1,)
    L = pm.Potential("L", pt.sum(dftreshape)) # sum returns shape ()
samples = pm.sample(draws=1024, model=model)

The error comes from pytensor/tensor/elemwise.py#L175 by way of pymc/sampling/mcmc.py and pytensor/gradient.py.

The following links address similar errors, and suggest simple reshaping as a workaround. The first link seems the most relevant since it also involves rfft().

I have made various attempts at reshaping, as suggested in the links above, but none have avoided the aforementioned error.

The `(False, False), [1]` portion of the error message (which is formatted as `{input_broadcastable}, {new_order}`) suggests that somewhere there is a two-dimensional object from which pytensor would like to drop the first dimension, but neither of the two dimensions is broadcastable. In response, I’ve made an attempt that inserts the line `dft = pt.specify_broadcastable(dft, *[0,1,2])` and renders downstream variables broadcastable, but the aforementioned error is still thrown.

How should I adjust this model to avoid this error during sampling?

Version info:

python                    3.12.2
pytensor                  2.25.2
pytensor-base             2.25.2
pymc                      5.16.2
pymc-base                 5.16.2
macOS 14.5
Apple M3 Max

Can you share a reproducible snippet?

This self-contained snippet reproduces the error.

import pytensor.tensor as pt
from pytensor.tensor import fft
import pymc as pm

# This is a bogus model that reproduces the error
n = 4
with pm.Model() as model:
    p = pm.Dirichlet("p", a=pt.ones(n))
    dft = pm.Deterministic("dft", fft.rfft(pt.shape_padaxis(p, 0))) # returns shape (1,n//2+1,2)
    dftreshape = pm.Deterministic("dftreshape", dft[0,:,0]) # returns shape (n//2+1,)
    L = pm.Potential("L", pt.sum(dftreshape)) # sum returns shape ()
samples = pm.sample(draws=1024, model=model)

Below is the resultant traceback.

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[1], line 12
     10     dftreshape = pm.Deterministic("dftreshape", dft[0,:,0]) # returns shape (n//2+1,)
     11     L = pm.Potential("L", pt.sum(dftreshape)) # sum returns shape ()
---> 12 samples = pm.sample(draws=1024, model=model)

File /Applications/anaconda3/envs/ACES/lib/python3.12/site-packages/pymc/sampling/mcmc.py:716, in sample(draws, tune, chains, cores, random_seed, progressbar, progressbar_theme, step, var_names, nuts_sampler, initvals, init, jitter_max_retries, n_init, trace, discard_tuned_samples, compute_convergence_checks, keep_warning_stat, return_inferencedata, idata_kwargs, nuts_sampler_kwargs, callback, mp_ctx, blas_cores, model, **kwargs)
    713         auto_nuts_init = False
    715 initial_points = None
--> 716 step = assign_step_methods(model, step, methods=pm.STEP_METHODS, step_kwargs=kwargs)
    718 if nuts_sampler != "pymc":
    719     if not isinstance(step, NUTS):

File /Applications/anaconda3/envs/ACES/lib/python3.12/site-packages/pymc/sampling/mcmc.py:223, in assign_step_methods(model, step, methods, step_kwargs)
    221 if has_gradient:
    222     try:
--> 223         tg.grad(model_logp, var)  # type: ignore
    224     except (NotImplementedError, tg.NullTypeGradError):
    225         has_gradient = False

File /Applications/anaconda3/envs/ACES/lib/python3.12/site-packages/pytensor/gradient.py:607, in grad(cost, wrt, consider_constant, disconnected_inputs, add_names, known_grads, return_disconnected, null_gradients)
    604     if hasattr(g.type, "dtype"):
    605         assert g.type.dtype in pytensor.tensor.type.float_dtypes
--> 607 _rval: Sequence[Variable] = _populate_grad_dict(
    608     var_to_app_to_idx, grad_dict, _wrt, cost_name
    609 )
    611 rval: MutableSequence[Variable | None] = list(_rval)
    613 for i in range(len(_rval)):

File /Applications/anaconda3/envs/ACES/lib/python3.12/site-packages/pytensor/gradient.py:1400, in _populate_grad_dict(var_to_app_to_idx, grad_dict, wrt, cost_name)
   1397     # end if cache miss
   1398     return grad_dict[var]
-> 1400 rval = [access_grad_cache(elem) for elem in wrt]
   1402 return rval

File /Applications/anaconda3/envs/ACES/lib/python3.12/site-packages/pytensor/gradient.py:1355, in _populate_grad_dict.<locals>.access_grad_cache(var)
   1353 for node in node_to_idx:
   1354     for idx in node_to_idx[node]:
-> 1355         term = access_term_cache(node)[idx]
   1357         if not isinstance(term, Variable):
   1358             raise TypeError(
   1359                 f"{node.op}.grad returned {type(term)}, expected"
   1360                 " Variable instance."
   1361             )

File /Applications/anaconda3/envs/ACES/lib/python3.12/site-packages/pytensor/gradient.py:1032, in _populate_grad_dict.<locals>.access_term_cache(node)
   1029 if node not in term_dict:
   1030     inputs = node.inputs
-> 1032     output_grads = [access_grad_cache(var) for var in node.outputs]
   1034     # list of bools indicating if each output is connected to the cost
   1035     outputs_connected = [
   1036         not isinstance(g.type, DisconnectedType) for g in output_grads
   1037     ]

File /Applications/anaconda3/envs/ACES/lib/python3.12/site-packages/pytensor/gradient.py:1355, in _populate_grad_dict.<locals>.access_grad_cache(var)
   1353 for node in node_to_idx:
   1354     for idx in node_to_idx[node]:
-> 1355         term = access_term_cache(node)[idx]
   1357         if not isinstance(term, Variable):
   1358             raise TypeError(
   1359                 f"{node.op}.grad returned {type(term)}, expected"
   1360                 " Variable instance."
   1361             )

File /Applications/anaconda3/envs/ACES/lib/python3.12/site-packages/pytensor/gradient.py:1032, in _populate_grad_dict.<locals>.access_term_cache(node)
   1029 if node not in term_dict:
   1030     inputs = node.inputs
-> 1032     output_grads = [access_grad_cache(var) for var in node.outputs]
   1034     # list of bools indicating if each output is connected to the cost
   1035     outputs_connected = [
   1036         not isinstance(g.type, DisconnectedType) for g in output_grads
   1037     ]

    [... skipping similar frames: _populate_grad_dict.<locals>.access_grad_cache at line 1355 (2 times), _populate_grad_dict.<locals>.access_term_cache at line 1032 (1 times)]

File /Applications/anaconda3/envs/ACES/lib/python3.12/site-packages/pytensor/gradient.py:1032, in _populate_grad_dict.<locals>.access_term_cache(node)
   1029 if node not in term_dict:
   1030     inputs = node.inputs
-> 1032     output_grads = [access_grad_cache(var) for var in node.outputs]
   1034     # list of bools indicating if each output is connected to the cost
   1035     outputs_connected = [
   1036         not isinstance(g.type, DisconnectedType) for g in output_grads
   1037     ]

File /Applications/anaconda3/envs/ACES/lib/python3.12/site-packages/pytensor/gradient.py:1355, in _populate_grad_dict.<locals>.access_grad_cache(var)
   1353 for node in node_to_idx:
   1354     for idx in node_to_idx[node]:
-> 1355         term = access_term_cache(node)[idx]
   1357         if not isinstance(term, Variable):
   1358             raise TypeError(
   1359                 f"{node.op}.grad returned {type(term)}, expected"
   1360                 " Variable instance."
   1361             )

File /Applications/anaconda3/envs/ACES/lib/python3.12/site-packages/pytensor/gradient.py:1185, in _populate_grad_dict.<locals>.access_term_cache(node)
   1177         if o_shape != g_shape:
   1178             raise ValueError(
   1179                 "Got a gradient of shape "
   1180                 + str(o_shape)
   1181                 + " on an output of shape "
   1182                 + str(g_shape)
   1183             )
-> 1185 input_grads = node.op.L_op(inputs, node.outputs, new_output_grads)
   1187 if input_grads is None:
   1188     raise TypeError(
   1189         f"{node.op}.grad returned NoneType, expected iterable."
   1190     )

File /Applications/anaconda3/envs/ACES/lib/python3.12/site-packages/pytensor/graph/op.py:398, in Op.L_op(self, inputs, outputs, output_grads)
    371 def L_op(
    372     self,
    373     inputs: Sequence[Variable],
    374     outputs: Sequence[Variable],
    375     output_grads: Sequence[Variable],
    376 ) -> list[Variable]:
    377     r"""Construct a graph for the L-operator.
    378 
    379     The L-operator computes a row vector times the Jacobian.
   (...)
    396 
    397     """
--> 398     return self.grad(inputs, output_grads)

File /Applications/anaconda3/envs/ACES/lib/python3.12/site-packages/pytensor/tensor/elemwise.py:290, in DimShuffle.grad(self, inp, grads)
    287     return [inp[0].zeros_like(dtype=config.floatX)]
    288 else:
    289     return [
--> 290         DimShuffle(tuple(s == 1 for s in gz.type.shape), grad_order)(
    291             Elemwise(scalar_identity)(gz)
    292         )
    293     ]

File /Applications/anaconda3/envs/ACES/lib/python3.12/site-packages/pytensor/tensor/elemwise.py:175, in DimShuffle.__init__(self, input_broadcastable, new_order)
    172             drop.append(i)
    173         else:
    174             # We cannot drop non-broadcastable dimensions
--> 175             raise ValueError(
    176                 "Cannot drop a non-broadcastable dimension: "
    177                 f"{input_broadcastable}, {new_order}"
    178             )
    180 # This is the list of the original dimensions that we keep
    181 self.shuffle = [x for x in new_order if x != "x"]

ValueError: Cannot drop a non-broadcastable dimension: (False, False), [1]

Seems to be a PyTensor bug. This raises the same error:

import pytensor.tensor as pt

p = pt.vector("p", shape=(4,))
out = pt.fft.rfft(p[None, :])
pt.grad(out.sum(), p)

Opened an issue here: Bug in gradient of fft.rfft · Issue #969 · pymc-devs/pytensor · GitHub