Hi,
You want to use pytensor.scan, which is a differentiable loop. The docs are here, and there is some discussion, plus an example that uses a scan together with a CustomDist (which you will want to do), here.