No, but you can compute a segment sum with a PyTensor `scan`:
import pytensor
import pytensor.tensor as pt
# Example inputs: `data` holds the values to sum, and `segment_ids[i]` names
# the segment that `data[i]` belongs to.
# NOTE(review): assumes segment ids are non-negative integers; a negative id
# would index the accumulator from the end instead of raising.
data = pt.as_tensor([5, 1, 7, 2, 3, 4, 1, 3])
segment_ids = pt.as_tensor([0, 0, 0, 1, 2, 2, 3, 3])

# One accumulator slot per id in [0, max(segment_ids)]. Taking the max (rather
# than segment_ids[-1]) gives the same count for sorted ids and stays correct
# when the ids are not sorted.
n_segments = segment_ids.max() + 1
# pt.zeros is given no dtype here, so the sums come out as floats (see below).
out = pt.zeros(n_segments)

# scan walks `data` and `segment_ids` in lockstep. Each step receives the
# current datum, its segment id, and the accumulator from the previous step
# (initialised from `out`), and returns that accumulator with `datum` added
# at position `segment_id`.
segment_sum_scan, _ = pytensor.scan(
    fn=lambda datum, segment_id, acc: acc[segment_id].inc(datum),
    sequences=[data, segment_ids],
    outputs_info=[out],
)
# scan returns the accumulator after every step; the last one holds the totals.
segment_sum = segment_sum_scan[-1]
# NOTE(review): an advanced-index increment such as
# pt.zeros(n_segments)[segment_ids].inc(data) should accumulate repeated ids
# without a scan (inc_subtensor defaults to ignore_duplicates=False) — confirm
# and benchmark before switching.

# debug eval
segment_sum.eval()  # array([13., 2., 7., 4.])