How to handle long loops better in PyMC?

I am trying to do Bayesian analysis of models such as the Markov-switching multifractal (MSM) model. At each time step, each volatility component either switches to a new value (a fresh sample) or keeps its previous value. Currently I use a naive approach: I loop over time in Python and create a volatility variable for each time t.
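In plain NumPy terms, the per-component recursion described above looks roughly like the sketch below. It is only an illustration: the function name hold_or_switch_np, the parameter values and the random seed are hypothetical, and the mean-one lognormal draws mirror mu = -a**2/2, sigma = a from the model code. An indicator of 1 keeps the previous value and 0 switches to the fresh draw, matching how PA1..PA3 are used in the loop.

```python
import numpy as np


def hold_or_switch_np(keep, draws):
    """Illustrative loop: path[0] = draws[0]; for i >= 1 carry the previous
    value when keep[i-1] == 1, otherwise switch to the fresh draw draws[i]."""
    path = np.empty_like(draws)
    path[0] = draws[0]
    for i in range(1, len(draws)):
        path[i] = path[i - 1] if keep[i - 1] == 1 else draws[i]
    return path


rng = np.random.default_rng(0)
T, a, gamma, r = 200, 0.5, 0.9, 1.05  # hypothetical parameter values

# Keep probabilities gamma, gamma**r, gamma**(r*r), one per component
probs = [gamma, gamma**r, gamma**(r * r)]

A_t = np.ones(T)
for p in probs:
    keep = rng.binomial(1, p, size=T - 1)                     # 0/1 keep indicators
    draws = rng.lognormal(mean=-a**2 / 2, sigma=a, size=T)    # mean-one draws
    A_t *= hold_or_switch_np(keep, draws)                     # product of components
```

Each component is a piecewise-constant path, and A_t is their product, as in A_t = A1_t * A2_t * A3_t in the model.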

import pymc as pm
import pytensor.tensor as pt

# Assumed: x and y are observed NumPy arrays of the same length T
T = len(y)

with pm.Model() as MSM:
    phi = pm.Uniform("phi", lower=0, upper=1)
    sig = pm.Gamma("sig", alpha=2, beta=1)
    gamma_A = pm.Uniform("gamma_A", lower=0, upper=1)
    rm_A = pm.HalfNormal("rm_A", sigma=0.1)
    rA = rm_A + 1
    a = pm.Gamma("a", alpha=2, beta=1)

    gamma_A1 = gamma_A
    gamma_A2 = gamma_A ** rA
    gamma_A3 = gamma_A ** (rA * rA)

    PA1 = pm.Bernoulli("PA1", p=gamma_A1, shape=(T-1,))
    PA2 = pm.Bernoulli("PA2", p=gamma_A2, shape=(T-1,))
    PA3 = pm.Bernoulli("PA3", p=gamma_A3, shape=(T-1,))

    A1 = pm.LogNormal("A1", mu=-a**2/2, sigma=a, shape=(T,))
    A2 = pm.LogNormal("A2", mu=-a**2/2, sigma=a, shape=(T,))
    A3 = pm.LogNormal("A3", mu=-a**2/2, sigma=a, shape=(T,))

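    # Each path starts at its first draw; afterwards it keeps the previous
    # value when PA*[i-1] == 1 and switches to the fresh draw otherwise.
    A1_t, A2_t, A3_t = [A1[0]], [A2[0]], [A3[0]]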
    for i in range(1,T):
        A1_t.append(A1_t[i-1]*(PA1[i-1]) + A1[i] * (1-PA1[i-1]))
        A2_t.append(A2_t[i-1]*(PA2[i-1]) + A2[i] * (1-PA2[i-1]))
        A3_t.append(A3_t[i-1]*(PA3[i-1]) + A3[i] * (1-PA3[i-1]))

    A1_t = pt.stack(A1_t)
    A2_t = pt.stack(A2_t)
    A3_t = pt.stack(A3_t)
    A_t = A1_t * A2_t * A3_t

    sig_t = sig * pt.sqrt(A_t)

    yobs = pm.Normal("yobs", mu=phi*x, sigma=sig_t, observed=y)

    posterior = pm.sample(10000, tune=1000)

However, this only works for short time series, up to about 50 observations. With series of a few hundred observations, I get "UserWarning: Loop fusion failed because the resulting node would exceed the kernel argument limit". Is there a better way to write this kind of loop in PyMC, and perhaps improve performance?

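For long loops in PyMC you should use PyTensor's scan (see "scan – Looping in PyTensor" in the PyTensor documentation). A Python for loop is unrolled into the graph, so the number of nodes grows with T; scan expresses the whole recursion as a single looping op, so the graph size no longer depends on the length of the series.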