ADVI Minibatch slows down with increasing size of data

Hi all,

I would like some pointers on using ADVI Minibatch with large data sets (1-5M records). I wrote a hierarchical model that uses ADVI Minibatch, and it runs pretty quickly, at ~400 it/s, when I use 20-50K records.

However, if I use the full data set, it slows down to ~5 it/s. I was under the impression that Minibatch just indexes into my data set, so I expected it to take longer in total (more iterations), but the time per iteration to stay the same.
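To make the expectation above concrete, here is a conceptual NumPy sketch of minibatch indexing. It is not PyMC3's implementation, and `draw_minibatch` is a hypothetical helper: each iteration draws `batch_size` random row indices (with replacement, a choice made for this sketch) and gathers the same rows from every array, so the rows gathered per iteration do not depend on the total number of records.

```python
import numpy as np

# Conceptual sketch only (hypothetical helper, not PyMC3's Minibatch):
# draw batch_size row indices and gather those rows from every array.
def draw_minibatch(rng, arrays, batch_size):
    n_rows = arrays[0].shape[0]
    # Indices drawn with replacement; this step costs O(batch_size), not O(n_rows)
    idx = rng.integers(0, n_rows, size=batch_size)
    # Reusing the same indices for every array keeps the rows aligned
    return [a[idx] for a in arrays]

rng = np.random.default_rng(0)
n_records, n_features, batch_size = 100_000, 20, 1500
X = rng.normal(size=(n_records, n_features))
cluster_idx = rng.integers(0, 30, size=n_records)

X_b, c_b = draw_minibatch(rng, [X, cluster_idx], batch_size)
print(X_b.shape, c_b.shape)  # (1500, 20) (1500,)
```

If index generation and gathering were all that happened each iteration, the per-iteration time would be set by `batch_size` and `n_features` alone.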

This is for BetaBinomial regression, as you can see in the code below. X.shape ~ (n_records, n_features), with n_records ~1M and n_features ~20. There are 30 groups (“clusters”), so len(set(cluster_idx)) = 30.
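For anyone trying to reproduce this, a minimal synthetic data generator with the shapes described above might look like the following. Everything here is an assumption: `make_synthetic_data` is a hypothetical helper, all values are made up, and `y` is laid out as (failures, successes) pairs, which is inferred from the function's `n = y.sum(axis=1)` line rather than stated in the post.

```python
import numpy as np

def make_synthetic_data(n_records=50_000, n_features=20, n_clusters=30, seed=0):
    """Hypothetical stand-in for the real data; all values are made up."""
    rng = np.random.default_rng(seed)
    X = rng.normal(size=(n_records, n_features))
    cluster_idx = rng.integers(0, n_clusters, size=n_records)
    # Number of trials per record (hypothetical range 1..49)
    trials = rng.integers(1, 50, size=n_records)
    # Success probability from a simple logistic model, for illustration only
    w = rng.normal(scale=0.2, size=n_features)
    p = 1.0 / (1.0 + np.exp(-(X @ w)))
    successes = rng.binomial(trials, p)
    # Columns are (failures, successes), so y.sum(axis=1) recovers the trials
    y = np.column_stack([trials - successes, successes])
    return X, y, cluster_idx
```

Calling `make_synthetic_data(n_records=1_000_000)` gives arrays of the full size, and a smaller `n_records` gives the 20-50K case, so both regimes can be timed with the same function.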

What am I missing?

import numpy as np
import pymc3 as pm
import theano.tensor as tt

def minibatch_beta_binomial_uvi(X, y, cluster_idx, batch_size=1500, model_name='default'):
    '''
    Run variational inference for hierarchical Bayesian logistic regression based on binomial-count data
    Params:
        X: Design matrix (n_records x n_features)
        y: Binomial response array of (failures, successes) pairs, i.e. array([(n_0 - k_0, k_0), (n_1 - k_1, k_1), ... ]); the number of trials n is the row sum
        cluster_idx: Array with the cluster index for each record
        batch_size: Optional batch size for minibatching
        model_name: Optional name for the model
    '''

    # Get successes (k) and trials (n) from y
    k = np.expand_dims(y[:, 1], axis=1)
    n = y.sum(axis=1)
    n = np.expand_dims(n, axis=1)
    
    # Number of clusters
    n_cluster_0 = len(set(cluster_idx))
    
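    # Generate minibatches. They all use pm.Minibatch's default random_seed, so they
    # draw the same row indices and X, k, n and cluster_idx stay aligned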
    X_t = pm.Minibatch(X, batch_size=batch_size)
    k_t = pm.Minibatch(k, batch_size=batch_size)
    n_t = pm.Minibatch(n, batch_size=batch_size)
    cluster_idx_0_t = pm.Minibatch(cluster_idx, batch_size=batch_size)

    with pm.Model() as beta_binomial_model:

        # Intercept Priors
        sd_intr = pm.HalfCauchy('intr_sd', beta=2.5)
        mu_intr = pm.Cauchy(name='mu_intr', alpha=0, beta=5)
        b_intr = pm.Normal('b_intr', mu=mu_intr, sd=sd_intr, shape=(n_cluster_0, 1))

        # Weights priors
        sigma_m = pm.HalfCauchy('sigma_m', beta=5, shape=(n_cluster_0, 1))
        mu_m = pm.Normal('mu_m', mu=0, sd=5 ** 2, shape=(n_cluster_0, 1))
     
        b_u = pm.Normal(name='b_u', mu=mu_m[cluster_idx_0_t], sd=sigma_m[cluster_idx_0_t], shape=(batch_size, X.shape[1]))
        u_mu_arg = tt.reshape((X_t * b_u).sum(axis=1), (batch_size, 1))

        # Logit transformation
        mu = pm.math.invlogit(b_intr[cluster_idx_0_t] + u_mu_arg)
        
        # Hyperpriors on Beta parameters
        # Here, the beta distribution is reparametrized by mu and kappa (the population mean and the "sample size", which is inversely related to the variance, respectively)
        kappa_log = pm.Exponential('kappa_log', lam=1.5)
        kappa = pm.Deterministic('kappa', tt.exp(kappa_log))
        alpha = pm.Deterministic('alpha', mu*kappa)
        beta = pm.Deterministic('beta', (1-mu)*kappa)

        yobs_name = 'Y_obs_' + model_name
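        # total_size should match the shape of the full observed data, k.shape == (n_records, 1);
        # X.shape would also multiply the likelihood scaling by n_features
        pm.BetaBinomial(yobs_name, n=n_t, alpha=alpha, beta=beta, observed=k_t, total_size=k.shape)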
    
    with beta_binomial_model:
        approx = pm.fit(10000, method='advi')
    
    return approx
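As a sanity check on the mu/kappa reparametrization in the model (alpha = mu*kappa, beta = (1-mu)*kappa), the sketch below computes the mean and variance of the implied Beta distribution. `beta_moments` is a hypothetical helper written for this check; it uses the standard Beta moment formulas, which reduce to a mean of mu and a variance of mu(1-mu)/(kappa+1), so a larger kappa means a tighter distribution.

```python
import numpy as np

def beta_moments(mu, kappa):
    # Same reparametrization as in the model: alpha = mu*kappa, beta = (1-mu)*kappa
    alpha = mu * kappa
    beta = (1.0 - mu) * kappa
    # Standard Beta(alpha, beta) mean and variance
    mean = alpha / (alpha + beta)
    var = alpha * beta / ((alpha + beta) ** 2 * (alpha + beta + 1.0))
    return mean, var

mu, kappa = 0.3, 10.0
mean, var = beta_moments(mu, kappa)
# Compare against the closed form: mean == mu, var == mu*(1-mu)/(kappa+1)
print(mean, var, mu * (1 - mu) / (kappa + 1))
```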

Interesting, I have not seen this before. I guess you could profile the function to find out why. @ferrine, do you have any idea?

I profiled the function, and it looks like the slowness comes from Theano (Op name: AdvancedSubtensor1), probably related to my matrix multiplication. But I’m not sure why this would happen, since the shape of X_t does not change as the data set grows (it’s always batch_size x n_features).

Interesting note: if I increase batch_size (by, say, a factor of 5), the inference speeds up significantly.

I’m still having a hard time making sense of all this.


I read somewhere that what grows is the time it takes to generate the indices for Minibatch: when there’s a lot of data, sampling the 1000 random rows for a minibatch takes a long time, for some reason?

I can’t find where I read this, and I’m not sure it’s true. I’ll try a simple example and see if this still happens.
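A simple NumPy-only version of that experiment could time one minibatch draw (index generation plus row gather) for a small and a large data set. The sketch compares two hypothetical index-sampling schemes: drawing `batch_size` integers with replacement, whose cost should not depend on N, and permuting all N rows then slicing, whose cost grows with N. It does not reproduce what `pm.Minibatch` or Theano's `AdvancedSubtensor1` actually do, so it can only show whether the index-generation hypothesis is plausible on the NumPy side. `time_per_draw` and its parameters are assumptions, and float32 is used only to keep memory down.

```python
import time
import numpy as np

def time_per_draw(n_rows, batch_size=1500, n_features=20, repeats=50, seed=0):
    """Average seconds per minibatch draw for two index-sampling schemes."""
    rng = np.random.default_rng(seed)
    X = rng.standard_normal((n_rows, n_features), dtype=np.float32)

    # Scheme 1: sample indices with replacement; cost depends on batch_size only
    start = time.perf_counter()
    for _ in range(repeats):
        idx = rng.integers(0, n_rows, size=batch_size)
        _ = X[idx]
    with_replacement = (time.perf_counter() - start) / repeats

    # Scheme 2: permute all rows and keep the first batch_size; cost grows with n_rows
    start = time.perf_counter()
    for _ in range(repeats):
        idx = rng.permutation(n_rows)[:batch_size]
        _ = X[idx]
    via_permutation = (time.perf_counter() - start) / repeats

    return with_replacement, via_permutation

for n_rows in (50_000, 1_000_000):
    fast, slow = time_per_draw(n_rows)
    print(f"N={n_rows:>9,}: integers {fast * 1e6:8.1f} us, permutation {slow * 1e6:8.1f} us")
```

If the first scheme stays flat while the second grows with N, an implementation that effectively touches all N rows per draw would be a candidate explanation; if both stay flat, the slowdown is more likely elsewhere in the compiled graph.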