For recursive computations that are long enough, you'll probably need to use Scan (see "scan – Looping in PyTensor" in the PyTensor dev documentation).
A different NUTS sampler, such as NumPyro's, may be needed to stay performant (see "Faster Sampling with JAX and Numba" in the PyMC example gallery).