Is there an equivalent of TensorFlow's tf.segment_sum or tf.unsorted_segment_sum?
I see there was a discussion about adding this to PyTorch: Pytorch equivalent to tf.unsorted_segment_sum - PyTorch Forums
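For reference, here is what the two TensorFlow ops compute, sketched in plain Python for 1-D data. The TensorFlow versions also work along the first axis of higher-rank tensors, and the helper names below are hypothetical, not TensorFlow's API. The idea: out[k] is the sum of data[i] over every i with segment_ids[i] == k, and an empty segment comes out as 0. The sorted variant infers the number of segments from the last id, while the unsorted variant needs the count passed in.

```python
def unsorted_segment_sum(data, segment_ids, num_segments):
    """Hypothetical 1-D reference: out[k] = sum of data[i] where segment_ids[i] == k."""
    if len(data) != len(segment_ids):
        raise ValueError("data and segment_ids must have the same length")
    out = [0] * num_segments  # empty segments stay 0
    for value, seg in zip(data, segment_ids):
        out[seg] += value  # one scalar update per element, so repeated ids accumulate
    return out


def segment_sum(data, segment_ids):
    """Hypothetical sorted variant: ids must be non-decreasing; count = last id + 1."""
    if any(a > b for a, b in zip(segment_ids, segment_ids[1:])):
        raise ValueError("segment_ids must be sorted")
    return unsorted_segment_sum(data, segment_ids, segment_ids[-1] + 1)


data = [5, 1, 7, 2, 3, 4, 1, 3]
print(segment_sum(data, [0, 0, 0, 1, 2, 2, 3, 3]))              # [13, 2, 7, 4]
print(unsorted_segment_sum(data, [3, 0, 3, 1, 2, 0, 2, 3], 4))  # [5, 2, 4, 15]
```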
No, but you can do it with a PyTensor scan:
import pytensor
import pytensor.tensor as pt
data = pt.as_tensor([5, 1, 7, 2, 3, 4, 1, 3])
segment_ids = pt.as_tensor([0, 0, 0, 1, 2, 2, 3, 3])
n_segments = segment_ids[-1] + 1 # or segment_ids.max() + 1 if they are not sorted
out = pt.zeros(n_segments)
segment_sum_scan, _ = pytensor.scan(
fn=lambda datum, segment_id, out: out[segment_id].inc(datum),
sequences=[data, segment_ids],
outputs_info=[out],
)
segment_sum = segment_sum_scan[-1]
# debug eval
segment_sum.eval() # array([13., 2., 7., 4.])
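The scan version ends with [-1] because pytensor.scan returns the carried state after every step, stacked along a new first axis, not just the final state. The NumPy loop below is only an eager analogue of that symbolic loop (it is not how PyTensor executes the graph), using the same example data:

```python
import numpy as np

data = np.array([5, 1, 7, 2, 3, 4, 1, 3])
segment_ids = np.array([0, 0, 0, 1, 2, 2, 3, 3])
n_segments = segment_ids.max() + 1  # works whether or not the ids are sorted

out = np.zeros(n_segments)  # initial state, playing the role of outputs_info=[out]
states = []                 # scan keeps the state produced by every step
for datum, segment_id in zip(data, segment_ids):  # one step per sequence element
    out = out.copy()           # each step yields a new array instead of mutating the old one
    out[segment_id] += datum   # same update as out[segment_id].inc(datum)
    states.append(out)

segment_sum_scan = np.stack(states)  # shape (len(data), n_segments), like scan's output
segment_sum = segment_sum_scan[-1]   # only the last state holds the complete sums
print(segment_sum.tolist())          # [13.0, 2.0, 7.0, 4.0]
```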
Actually, that was overkill. Just do:
segment_sum = pt.zeros(n_segments)[segment_ids].inc(data)  # repeated ids accumulate into the same entry
# debug eval
segment_sum.eval() # array([13., 2., 7., 4.])
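The one-liner works because inc with an integer-array index adds every element of data even when an index repeats; the 13. for segment 0 in the output above is 5 + 1 + 7. NumPy's own x[idx] += y does not behave that way. The NumPy comparison below illustrates the difference, with np.add.at playing the accumulating role (this is a NumPy illustration only, not PyTensor code):

```python
import numpy as np

data = np.array([5, 1, 7, 2, 3, 4, 1, 3])
segment_ids = np.array([0, 0, 0, 1, 2, 2, 3, 3])
n_segments = segment_ids.max() + 1

# Buffered fancy-index increment: each target element is read once and
# written once, so a repeated index is incremented only once.
buffered = np.zeros(n_segments)
buffered[segment_ids] += 1
print(buffered.tolist())  # [1.0, 1.0, 1.0, 1.0]

# np.add.at is unbuffered: every increment is applied, so repeated
# indices accumulate, which is what a segment sum needs.
counts = np.zeros(n_segments)
np.add.at(counts, segment_ids, 1)
print(counts.tolist())  # [3.0, 1.0, 2.0, 2.0]

sums = np.zeros(n_segments)
np.add.at(sums, segment_ids, data)
print(sums.tolist())  # [13.0, 2.0, 7.0, 4.0]
```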
Why can't I find .inc() anywhere in the documentation?
Thanks anyway, it works!
Because our documentation of the variable methods is not great. The inc method is equivalent to the pt.inc_subtensor function, which you should be able to find in the docs: x[idx].inc(y) is the same as pt.inc_subtensor(x[idx], y).
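One consequence of inc being inc_subtensor is that nothing is modified in place: the call returns a new variable equal to the original tensor with the increments applied. The hypothetical NumPy function below models only that resulting value. It is not PyTensor code, and its (x, idx, y) signature is an assumption for illustration; pt.inc_subtensor itself takes the indexed tensor x[idx] and y.

```python
import numpy as np


def inc_subtensor_model(x, idx, y):
    """Hypothetical model of the value of pt.inc_subtensor(x[idx], y).

    Returns a new float array equal to x with y added at idx; repeated
    indices accumulate. The input x itself is never modified.
    """
    result = np.array(x, dtype=float)  # np.array copies, so x stays untouched
    np.add.at(result, idx, y)          # unbuffered add: duplicate indices accumulate
    return result


x = np.zeros(4)
y = inc_subtensor_model(x, [0, 0, 2], [1.0, 2.0, 5.0])
print(x.tolist())  # [0.0, 0.0, 0.0, 0.0]
print(y.tolist())  # [3.0, 0.0, 5.0, 0.0]
```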