Error while fine-tuning a random-intercept model using ADVI

I played with the code and got it to work.

# Seed NumPy's global RNG so every draw below is reproducible.
np.random.seed(42)

# Dimensions of the simulated panel: m rows (one per individual),
# n observations per row.
m = 1000
n = 100

# One standard-normal regressor value per row.
X = norm.rvs(size=m, loc=0, scale=1)

# Ground-truth fixed effects: intercept `a_true` and slope `b_true`.
a_true, b_true = 0.1, 0.5

# Ground-truth random intercepts (sd 0.1), one per row.
g_true = norm.rvs(size=m, loc=0, scale=0.1)

# Every row shares a single mean; keep it as an (m, 1) column so it
# broadcasts across the n observations of that row.
row_means = (a_true + b_true * X + g_true)[:, np.newaxis]
y = norm.rvs(loc=row_means, scale=1, size=(m, n))

# Size of the initial subset and the row indices it is sliced from.
bs = 100
indices = np.arange(m)

# Model: random-intercept regression, fitted with ADVI in two stages
# (first on the leading `bs` rows, then continued on the full dataset).
with pm.Model() as model:
    # Data containers for X and y. They start out holding only the first
    # `bs` rows and are swapped for the full arrays via `pm.set_data` below.
    X_t = pm.Data("X_t", X[:bs])
    y_t = pm.Data("y_t", y[:bs])

    # Priors: intercept, slope, one random intercept per row, noise scale.
    a = pm.Normal("a", sigma=1)
    b = pm.Normal("b", sigma=1)
    g = pm.Normal("g", sigma=0.1, size=m)  # Always length m, independent of the data currently loaded
    sigma = pm.HalfNormal("sigma", sigma=1)

    # Row indices of the data currently loaded; used to pick the matching
    # entries of `g`, so it must be updated together with X_t and y_t.
    idx = pm.Data("idx", indices[:bs])

    # Per-row mean as a column so it broadcasts over the n observations
    # in each row of y_t.
    mu = a + b * X_t[:, None] + g[idx][:, None]

    # Likelihood
    obs = pm.Normal("obs", mu=mu, sigma=sigma, observed=y_t)

    # Use ONE inference object for both stages. `pm.fit(method="advi")`
    # constructs a fresh ADVI instance on every call, so calling it twice
    # would discard the first fit instead of fine-tuning it.
    advi = pm.ADVI()

    # Stage 1: fit on the first `bs` rows only. NOTE: this is a fixed
    # subset, not randomly resampled minibatches.
    advi.fit(n=25000)

    # Stage 2: load the full dataset and continue optimising the same
    # variational parameters (warm start from stage 1).
    pm.set_data({"X_t": X, "y_t": y, "idx": indices})
    mean_field = advi.fit(n=1000)

# Draw 2000 samples from the fitted variational approximation.
post = mean_field.sample(draws=2000)

# Posterior predictive samples for those draws; the model is passed
# explicitly rather than through a `with model:` context.
ppc = pm.sample_posterior_predictive(post, model=model)

The posterior predictive check (PPC) looks good too.

image

Hope that works for you.