Error while fine-tuning a random intercept model using ADVI

Hi everyone, I am trying to fit a basic random intercept model using ADVI. I train on minibatches first and then fine-tune on the full dataset. However, the fine-tuning step gives me a dimension mismatch.

import numpy as np
import pymc as pm
from scipy.stats import norm

np.random.seed(42)

m = 1000
n = 100

# regressors
X = norm.rvs(loc=0, scale=1, size=m)

# parameters
a_true = 0.1
b_true = 0.5
# random intercepts
g_true = norm.rvs(loc=0, scale=0.1, size=m)

# each row has a different mean
y = norm.rvs(loc=a_true + b_true * X[:, np.newaxis] + g_true[:, np.newaxis], scale=1, size=(m, n))

# minibatches
bs = 100
indices = np.arange(m)
X_t, y_t, idx = pm.Minibatch(X, y, indices, batch_size=bs)

model = pm.Model()

with model:
    a = pm.Normal("a", sigma=1)
    b = pm.Normal("b", sigma=1)
    g = pm.Normal("g", sigma=0.1, size=m)
    sigma = pm.HalfNormal("sigma", sigma=1)
    mu = a + b * X_t[:, None] + g[idx][:, None]
    obs = pm.Normal("obs", mu=mu, sigma=sigma, observed=y_t, total_size=y.shape)
    # Minibatch phase
    mean_field = pm.fit(n=25_000, method="advi")
    # Fine-tuning phase on the full dataset (this is the call that produces the errors)
    mean_field = pm.fit(n=1_000, method="advi", more_replacements={X_t: X, y_t: y, idx: indices})

post = mean_field.sample(2000)
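
For context on `total_size=y.shape` in the model above: as far as I understand it, PyMC uses this argument to rescale the minibatch log-likelihood so that it estimates the full-data log-likelihood. The NumPy sketch below is only an illustration of that idea, under the assumption that the scale factor is the product of `total / batch` over the axes (here `1000 / 100` times `100 / 100`). The helper `gaussian_loglik` and the fixed `mu=0`, `sigma=1` are made up for the example and are not part of the model.

```python
import numpy as np

rng = np.random.default_rng(0)
m, n, bs = 1000, 100, 100
y = rng.normal(size=(m, n))


def gaussian_loglik(values, mu=0.0, sigma=1.0):
    """Sum of Normal(mu, sigma) log-densities over all entries."""
    z = (values - mu) / sigma
    return float(np.sum(-0.5 * z**2 - np.log(sigma) - 0.5 * np.log(2 * np.pi)))


# Log-likelihood of the full (m, n) array
full = gaussian_loglik(y)

# Assumed scaling rule for a (bs, n) batch with total_size=(m, n)
scale = (m / bs) * (n / n)

# Split the rows into m // bs disjoint batches and rescale each batch
batches = y.reshape(m // bs, bs, n)
scaled = np.array([scale * gaussian_loglik(batch) for batch in batches])

# Averaged over the batches, the rescaled value matches the full log-likelihood
print(full, scaled.mean())
```

Each rescaled batch value differs from `full`, but their average over the partition agrees with it, which is the sense in which the rescaled minibatch term stands in for the full-data term.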

The minibatch phase runs fine, but the fine-tuning step produces errors relating to a dimension mismatch:

Finished [100%]: Average Loss = 14,936
ERROR (pytensor.graph.rewriting.basic): Rewrite failure due to: constant_folding
ERROR (pytensor.graph.rewriting.basic): node: Assert{msg='Could not broadcast dimensi...'}(1000, False)
ERROR (pytensor.graph.rewriting.basic): TRACEBACK:
ERROR (pytensor.graph.rewriting.basic): Traceback (most recent call last):
  File ".../lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1909, in process_node
    replacements = node_rewriter.transform(fgraph, node)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1081, in transform
    return self.fn(fgraph, node)
           ^^^^^^^^^^^^^^^^^^^^^
  File ".../lib/python3.11/site-packages/pytensor/tensor/rewriting/basic.py", line 1123, in constant_folding
    required = thunk()
               ^^^^^^^
  File ".../lib/python3.11/site-packages/pytensor/graph/op.py", line 524, in rval
    r = p(n, [x[0] for x in i], o)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../lib/python3.11/site-packages/pytensor/raise_op.py", line 105, in perform
    raise self.exc_type(self.msg)
AssertionError: Could not broadcast dimensions. Broadcasting is only allowed along axes that have a statically known length 1. Use `specify_broadcastable` to inform PyTensor of a known shape.

ERROR (pytensor.graph.rewriting.basic): Rewrite failure due to: local_fill_to_alloc
ERROR (pytensor.graph.rewriting.basic): node: Second(sigma > 0, True_div.0)
ERROR (pytensor.graph.rewriting.basic): TRACEBACK:
ERROR (pytensor.graph.rewriting.basic): Traceback (most recent call last):

...
  File ".../lib/python3.11/site-packages/pytensor/raise_op.py", line 105, in perform
    raise self.exc_type(self.msg)
AssertionError: Could not broadcast dimensions. Broadcasting is only allowed along axes that have a statically known length 1. Use `specify_broadcastable` to inform PyTensor of a known shape.
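
One possible reading of the message, not confirmed in this thread: after the replacement, one input has length 1000 along the first axis (the `1000` in the `Assert` node) while another part of the graph still has the minibatch length of 100. The NumPy analogue below only illustrates that kind of shape conflict. It is not the PyTensor code path, and `mu_full`, `y_batch`, and `broadcast_ok` are names made up for the illustration.

```python
import numpy as np

m, n, bs = 1000, 100, 100

mu_full = np.zeros((m, 1))   # mean built from the full X and the full index vector
y_batch = np.zeros((bs, n))  # shape the likelihood had during the minibatch phase

# Lengths 1000 and 100 on the first axis cannot be broadcast together,
# because neither of them is 1
try:
    np.broadcast_shapes(mu_full.shape, y_batch.shape)
    broadcast_ok = True
except ValueError:
    broadcast_ok = False

# When both operands agree on the first axis, broadcasting is fine
full_shape = np.broadcast_shapes((m, 1), (m, n))
batch_shape = np.broadcast_shapes((bs, 1), (bs, n))

print(broadcast_ok, full_shape, batch_shape)
```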

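I played with the code and got it to work by swapping `pm.Minibatch` for `pm.Data` containers and loading the full dataset with `pm.set_data` before the second fit.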

import numpy as np
import pymc as pm
from scipy.stats import norm

np.random.seed(42)

m = 1000
n = 100

# Regressors
X = norm.rvs(loc=0, scale=1, size=m)

# Parameters
a_true = 0.1
b_true = 0.5

# Random intercepts
g_true = norm.rvs(loc=0, scale=0.1, size=m)

# Each row has a different mean
y = norm.rvs(loc=a_true + b_true * X[:, np.newaxis] + g_true[:, np.newaxis], scale=1, size=(m, n))

# Minibatches
bs = 100
indices = np.arange(m)

# Model
with pm.Model() as model:
    # Data containers for X and y
    X_t = pm.Data("X_t", X[:bs])  # Start with a minibatch of size `bs`
    y_t = pm.Data("y_t", y[:bs])
    
    # Define the priors
    a = pm.Normal("a", sigma=1)
    b = pm.Normal("b", sigma=1)
    g = pm.Normal("g", sigma=0.1, size=m)  # Random intercepts, one for each individual
    sigma = pm.HalfNormal("sigma", sigma=1)
    
    # Properly index g with the batch indices for minibatch training
    idx = pm.Data("idx", indices[:bs])  # Index for the minibatch
    
    # Define the mean function with proper broadcasting
    mu = a + b * X_t[:, None] + g[idx][:, None]  # Use idx to slice the random intercepts correctly
    
    # Likelihood
    obs = pm.Normal("obs", mu=mu, sigma=sigma, observed=y_t)
    
    # Fit with ADVI on the first `bs` rows currently held in the data containers
    mean_field = pm.fit(n=25000, method="advi")
    
    # Update data containers with full dataset for fine-tuning
    pm.set_data({"X_t": X, "y_t": y, "idx": indices})  # Replace minibatch data with full data
    mean_field = pm.fit(n=1000, method="advi")

post = mean_field.sample(2000)

with model:
    ppc = pm.sample_posterior_predictive(post)

The PPC check looks good too.
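
A quick cross-check that does not depend on PyMC, offered as a sketch rather than as part of the fix: averaging each row of `y` over its `n` columns and regressing those row means on `X` by ordinary least squares should give estimates close to `a_true` and `b_true`, which can then be compared with the ADVI posterior means of `a` and `b`. The names `a_hat` and `b_hat` are mine, and the data are regenerated with NumPy's `default_rng`, so the draws differ from the `norm.rvs` ones above.

```python
import numpy as np

rng = np.random.default_rng(42)
m, n = 1000, 100
a_true, b_true = 0.1, 0.5

# Same data-generating process as in the model code
X = rng.normal(0.0, 1.0, size=m)
g_true = rng.normal(0.0, 0.1, size=m)
y = rng.normal(a_true + b_true * X[:, None] + g_true[:, None], 1.0, size=(m, n))

# Row means average out most of the observation noise
y_bar = y.mean(axis=1)

# Ordinary least squares of the row means on an intercept and X
design = np.column_stack([np.ones(m), X])
(a_hat, b_hat), *_ = np.linalg.lstsq(design, y_bar, rcond=None)

print(f"a_hat = {a_hat:.3f}, b_hat = {b_hat:.3f}")
```

If the posterior means from `post` land far from these least-squares values, that would be a hint to look at the fit again.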

Hope that works for you.