PyTensor version of TensorFlow segment_sum

Is there a PyTensor equivalent of TensorFlow's tf.segment_sum or tf.unsorted_segment_sum?

I see there was a discussion about adding this to PyTorch: "Pytorch equivalent to tf.unsorted_segment_sum" on the PyTorch Forums.

No, but you can do it with a PyTensor scan:

import pytensor
import pytensor.tensor as pt

data = pt.as_tensor([5, 1, 7, 2, 3, 4, 1, 3])
segment_ids = pt.as_tensor([0, 0, 0, 1, 2, 2, 3, 3])
n_segments = segment_ids[-1] + 1  # or segment_ids.max() + 1 if they are not sorted
out = pt.zeros(n_segments)
segment_sum_scan, _ = pytensor.scan(
  fn=lambda datum, segment_id, out: out[segment_id].inc(datum),
  sequences=[data, segment_ids],
  outputs_info=[out],
)
segment_sum = segment_sum_scan[-1]

# debug eval
segment_sum.eval()  # array([13.,  2.,  7.,  4.])
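To see what the scan does step by step, here is a plain-Python sketch of the same loop. The function name segment_sum_loop is made up for illustration. Each iteration consumes one (datum, segment_id) pair from the sequences and adds datum into the running out vector, which is what out[segment_id].inc(datum) does symbolically. The scan also keeps every intermediate out, which is why the final state is taken with segment_sum_scan[-1]; this loop only keeps the last one.

```python
def segment_sum_loop(data, segment_ids, n_segments):
    """Plain-Python mirror of the scan above (hypothetical helper, for illustration)."""
    out = [0.0] * n_segments  # plays the role of outputs_info=[pt.zeros(n_segments)]
    for datum, segment_id in zip(data, segment_ids):  # the two `sequences`
        out[segment_id] += datum  # same effect as out[segment_id].inc(datum)
    return out  # only the final state, i.e. segment_sum_scan[-1]


data = [5, 1, 7, 2, 3, 4, 1, 3]
segment_ids = [0, 0, 0, 1, 2, 2, 3, 3]
print(segment_sum_loop(data, segment_ids, max(segment_ids) + 1))  # [13.0, 2.0, 7.0, 4.0]
```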

Actually, that was overkill. Just do:

segment_sum = pt.zeros(n_segments)[segment_ids].inc(data)  # repeated ids accumulate

# debug eval
segment_sum.eval()  # array([13.,  2.,  7.,  4.])
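For comparison, a NumPy sketch of the same idea (plain NumPy, not PyTensor; the data and the shuffled ids are made up). The ids are unsorted, so the length comes from max() + 1. It also shows why the duplicate handling matters: np.add.at is unbuffered, so repeated indices accumulate, which is the behavior the .inc one-liner relies on. A plain fancy-index += is buffered and writes each repeated index only once, so it does not give a segment sum.

```python
import numpy as np

data = np.array([5, 1, 7, 2, 3, 4, 1, 3], dtype=float)
segment_ids = np.array([2, 0, 3, 0, 1, 2, 0, 3])  # unsorted ids (made-up example)
n_segments = segment_ids.max() + 1  # sorted order not required

# Unbuffered: every (id, value) pair is added, so repeated ids accumulate
segment_sum = np.zeros(n_segments)
np.add.at(segment_sum, segment_ids, data)
print(segment_sum)  # [ 4.  3.  9. 10.]

# Buffered: each repeated id receives only one of its values, NOT a segment sum
buffered = np.zeros(n_segments)
buffered[segment_ids] += data
```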

Why can I not find .inc() anywhere in the documentation?

Thanks anyways, it works!

Because our documentation of the variable methods is not great. The inc method is equivalent to the pt.inc_subtensor function, which you should be able to find in the docs.
