Memory Usage Spikes During "Transforming Variables" While Sampling a Logit Model

I'm new to PyMC and Bayesian models, and I'm currently working on a logit model. The model is fairly complex with lots of parameters; simplifying it isn't an option because the goal is to replicate a legacy model that was developed in C++. The dataset I'm working with has about 60k rows, but my laptop with 32 GB of RAM cannot handle a sample of more than 10k observations. Watching memory while sampling, usage jumps from about 15 GB to the full 32 GB as soon as the "transforming variables" stage starts.

I've tried the numpyro and blackjax samplers as well. Here is my code; I'd really appreciate any advice or suggestions. Thank you!
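
For reference, trying the other samplers only meant changing the nuts_sampler argument, with everything else the same:

idata = pm.sample(nuts_sampler="numpyro", random_seed=1033)  # and likewise "blackjax"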

import pandas as pd
import pymc as pm

coords = {
    "dim1": dim1,
    "dim2": dim2,
    "dim3": dim3,
    "Variables": Variables,
    "obs": range(N),
}
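
# dim1_idx / dim3_idx are integer arrays of length N mapping each row into the
# "dim1" / "dim3" coords. Roughly like this (the column names here are
# placeholders for however they are actually derived):
dim1_idx = pd.Categorical(df["dim1"], categories=dim1).codes
dim3_idx = pd.Categorical(df["dim3"], categories=dim3).codes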

with pm.Model(coords=coords) as model:

    # Shared hyperpriors: one location and one scale entry per coefficient name
    mu = pm.Normal('mu', 0, 1000, dims="Variables")
    sigma = pm.ChiSquared('sigma', 10, dims="Variables")

    # Hierarchical Normal prior for a named coefficient; the inverse of the
    # ChiSquared draw serves as the standard deviation
    def coef(name, dims):
        i = Variables.index(name)
        return pm.Normal(name, mu[i], sigma[i] ** -1, dims=dims)

    InterceptL_0 = coef("InterceptL_0", "dim3")
    InterceptF_0 = coef("InterceptF_0", "dim3")

    InterceptC_M = coef("InterceptC_M", ("dim1", "dim3"))
    InterceptL_M = coef("InterceptL_M", ("dim1", "dim3"))
    InterceptF_M = coef("InterceptF_M", ("dim1", "dim3"))

    Var1_ = coef("Var1_", "dim3")

    # Per-row utilities for the four alternatives of the first choice model
    u_A = InterceptC_M[dim1_idx, dim3_idx] + Var1_[dim3_idx] * df['Var1_C']
    u_B = InterceptL_M[dim1_idx, dim3_idx] + Var1_[dim3_idx] * df['Var1_L']
    u_C = InterceptF_M[dim1_idx, dim3_idx] + Var1_[dim3_idx] * df['Var1_F']
    u_D = InterceptF_M[dim1_idx, dim3_idx] + Var1_[dim3_idx] * df['Var1_F']  # identical to u_C as written

    # Stack the four utilities into shape (obs, 4) and map to choice probabilities
    TT_stack = pm.math.stack([u_A, u_B, u_C, u_D]).T

    p_TT = pm.Deterministic("p_TT", pm.math.softmax(TT_stack, axis=1), dims=("obs", "dim2"))

    like_TT = pm.Categorical("like_TT", p=p_TT, observed=dv_TT, dims="obs")

    Intercept = coef("Intercept", ("dim1", "dim3"))
    v = coef("v", "dim3")
    Var2_ = pm.HalfNormal("Var2_", sigma=sigma[Variables.index("Var2_")] ** -1, dims="dim3")
    Var3_ = coef("Var3_", "dim1")

    # Utilities for the second (dim1) choice model; logsumexp is the numerically
    # stable log(sum(exp(...))) of the first-stage utilities (the inclusive
    # value). Note that the expression does not vary with i.
    logsum_TT = pm.math.logsumexp(TT_stack, axis=1, keepdims=False)
    utilModel = []
    for i in range(len(dim1)):
        utilModel.append(
            Intercept[dim1_idx, dim3_idx]
            - Var2_[dim3_idx] * df['Var2_']
            + Var3_[dim1_idx] * df['Var3_']
            + v[dim3_idx] * logsum_TT
        )

    model_stack = pm.math.stack(utilModel).T

    p_model = pm.Deterministic("p_model", pm.math.softmax(model_stack, axis=1), dims=("obs", "dim1"))

    like_model = pm.Categorical("like_model", p=p_model, observed=dv_model, dims="obs")

    idata = pm.sample(nuts_sampler="pymc",
                      random_seed=1033)
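
One thing I suspect is the two Deterministics: p_TT and p_model both carry the "obs" dimension, so every draw of every chain stores a full N-by-alternatives array, and the "transforming variables" stage has to materialize all of them at once. Would dropping them and passing the logits straight to the likelihoods be the right fix? A sketch of what I mean (if I'm reading the docs correctly, pm.Categorical accepts logits via logit_p, which should be equivalent to p=softmax(logits)):

    # Sketch: skip the per-observation Deterministics so the trace only holds
    # the (much smaller) parameters rather than N-row probability arrays
    like_TT = pm.Categorical("like_TT", logit_p=TT_stack, observed=dv_TT, dims="obs")
    like_model = pm.Categorical("like_model", logit_p=model_stack, observed=dv_model, dims="obs")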