Several minibatch parameters

Hi! I have a regression model which I fit using ADVI (NUTS sampling is extremely slow), and I would like to try minibatches because the data size is ~10^5-10^6. With plain ADVI the fit reaches reasonable convergence in several hours.

The model looks like this:

with pm.Model() as model:
    G0_amp = pm.Lognormal('G0_amp', 0, 1, shape=(n_1, n_2, n_3, n_4))
    G0_phas = pm.VonMises('G0_phas', mu=0, kappa=1/1**2, shape=(n_1, n_2, n_3, n_4))
    
    gain_amp = G0_amp[ix1[:, 0], ix2, :, :] * G0_amp[ix1[:, 1], ix2, :, :]
    gain_phas = G0_phas[ix1[:, 0], ix2, :, :] - G0_phas[ix1[:, 1], ix2, :, :]

    # ... computing mu from gain_amp and gain_phas ...
    
    obs_r = pm.StudentT('obs_r', mu=mu[~mask], sd=sd, nu=2, observed=data[~mask].real)
    obs_i = pm.StudentT('obs_i', mu=mu[~mask], sd=sd, nu=2, observed=data[~mask].imag)

So, the arrays to minibatch are ix1 (shape n*2), ix2 (shape n), and data and mask (shape n*n_3*n_4 for both). The code above works OK, but I'd like to speed it up by minibatching across the n dimension.

Trying the very intuitive code:

bs = 1000
ix1_mb = pm.Minibatch(ix1, batch_size=bs)
ix2_mb = pm.Minibatch(ix2, batch_size=bs)
data_r_mb = pm.Minibatch(data.real, batch_size=bs)
data_i_mb = pm.Minibatch(data.imag, batch_size=bs)
mask_mb = pm.Minibatch(mask, batch_size=bs)

# replace everywhere in the model <param> with <param>_mb

doesn't seem to work: it gives a much higher loss, and the resulting trace is very different from the first model. My guess is that the different minibatches draw different random slices, so they no longer correspond to the same indices, but I may be wrong about that.

What should be the correct way to do this?

Usually the minibatches are synchronised, at least for arrays with the same length along the first dimension. You can validate it with something like:

import numpy as np
import pymc3 as pm

X = np.repeat(np.arange(100), 5).reshape(100, 5)
y = np.arange(100)
xm = pm.Minibatch(X)
ym = pm.Minibatch(y)
# the printed rows of xm should correspond to the printed values of ym
print(xm.eval(), ym.eval())

However, in your case, since there are multiple dimensions, you should check the minibatches by calling .eval() on them. Also, I feel like since you are minibatching the mask, maybe you don't need to minibatch the data (so data_r_mb and data_i_mb would not be needed)?
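For example, something along these lines (made-up shapes, just to illustrate the check):

import numpy as np
import pymc3 as pm

n = 1000
ix2 = np.arange(n)
data_r = np.random.rand(n, 8, 2)      # stand-in for data.real

ix2_mb = pm.Minibatch(ix2, batch_size=5)
data_r_mb = pm.Minibatch(data_r, batch_size=5)

rows = ix2_mb.eval()       # the first-axis indices that were drawn (ix2[i] == i here)
batch = data_r_mb.eval()   # the slice drawn from the 3-D array

# if the two minibatches are synchronized, these are the same rows:
print(np.allclose(batch, data_r[rows]))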

But if I don't minibatch the data, how do I apply the mask? mask_mb is basically 500x8x2 and data_r is like 30000x8x2.

Yeah, I am a bit confused by the mask here. If in the original model the observed is data[~mask], why not just use data = data[~mask] up front?

Well, the point is that data and mask have shape n*8*2, and that is the natural shape for mu given how it is constructed (the gain_* arrays also have shape n*8*2). data[~mask] (as well as mu[~mask]) is a one-dimensional vector whose size equals the number of False entries in mask, which makes it impossible to minibatch data[~mask] together with, e.g., ix2 (which has length n).
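A tiny numpy illustration of the shape problem (numbers made up):

import numpy as np

n = 6
data = np.random.rand(n, 8, 2)
mask = np.random.rand(n, 8, 2) > 0.3   # boolean mask with the same shape

print(data.shape)          # (6, 8, 2)
print(data[~mask].shape)   # (k,) -- flattened to 1-D, k = number of False entries in mask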

I see… maybe something like:

import theano

data_shared = theano.shared(data)
ix1shared = theano.shared(ix1)
ix2shared = theano.shared(ix2)
maskshared = theano.shared(~mask)   # note: the negation is baked in here

# one random index vector of length batchshape, redrawn at each evaluation
ridx = pm.tt_rng().uniform(size=(batchshape,), low=0,
                           high=data.shape[0] - 1e-10).astype('int64')

then replace in your code:
ix1[:, 0], ix2 --> ix1shared[ridx, 0], ix2shared[ridx]
mu[~mask] --> mu[maskshared[ridx, :, :]]
data[~mask] --> data_shared[maskshared[ridx, :, :]]

This is to ensure consistent indices, right? And pm.tt_rng somehow generates a different sample at each iteration, but not each time it is used in the code?
Also, do you mean data[ridx][maskshared[ridx]] in the last line?

I think this would use the same random index at each evaluation, which should produce the synchronized behaviour.

No, I would only minibatch the mask as that already controls the indexing to the data and mu.

"the same random index at each evaluation"

But it seems like ridx gets evaluated multiple times, namely when indexing each of the arrays.

"No, I would only minibatch the mask as that already controls the indexing to the data and mu."

Then I totally do not understand how that should work. maskshared has the same shape as mask, so n*8*2. Now, maskshared[ridx, :, :] has shape 500*8*2, right? But data still has n*8*2, so how can we index data with maskshared[ridx, :, :]?

It is used multiple times, but when you evaluate the tensor that represents the loss (the KLqp objective), the computation is synced, i.e. ridx is drawn once per evaluation (not completely sure about this, so please verify it on your side).
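For instance, a quick standalone check of this (a sketch, outside the model), slicing two shared arrays with a single ridx:

import numpy as np
import theano
import pymc3 as pm

x = theano.shared(np.arange(10))
y = theano.shared(np.arange(10) * 100)

# a single random index tensor, used to slice both shared arrays
ridx = pm.tt_rng().uniform(size=(3,), low=0, high=10 - 1e-10).astype('int64')
f = theano.function([], [x[ridx], y[ridx]])

xa, ya = f()
print(xa, ya)    # if the draw is shared within one call, ya == 100 * xa
xb, yb = f()
print(xb, yb)    # a new draw on the next call, but still aligned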

You are right, I was thinking of an indexing (integer) mask, not a boolean 0-1 mask. Yes, in that case you do need to index data by ridx first as well.
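Putting the pieces together, roughly what the minibatched model might look like (a sketch only: bs is a placeholder batch size, n_1 ... n_4, ix1, ix2, data, mask and sd are as in the original post, the mu construction is elided, and the real and imaginary parts are kept in separate shared variables to avoid complex tensors):

import theano
import pymc3 as pm

bs = 500  # minibatch size (placeholder)

# shared copies of the full-size arrays
data_r_shared = theano.shared(data.real)
data_i_shared = theano.shared(data.imag)
ix1shared = theano.shared(ix1)
ix2shared = theano.shared(ix2)
maskshared = theano.shared(~mask)        # True = keep this entry

# one random index vector per evaluation, reused for every array below
ridx = pm.tt_rng().uniform(size=(bs,), low=0,
                           high=data.shape[0] - 1e-10).astype('int64')

with pm.Model() as model:
    G0_amp = pm.Lognormal('G0_amp', 0, 1, shape=(n_1, n_2, n_3, n_4))
    G0_phas = pm.VonMises('G0_phas', mu=0, kappa=1/1**2, shape=(n_1, n_2, n_3, n_4))

    gain_amp = (G0_amp[ix1shared[ridx, 0], ix2shared[ridx], :, :]
                * G0_amp[ix1shared[ridx, 1], ix2shared[ridx], :, :])
    gain_phas = (G0_phas[ix1shared[ridx, 0], ix2shared[ridx], :, :]
                 - G0_phas[ix1shared[ridx, 1], ix2shared[ridx], :, :])

    # ... compute mu (shape bs x n_3 x n_4) from gain_amp and gain_phas ...

    keep = maskshared[ridx].nonzero()    # positions of the unmasked entries in this batch

    obs_r = pm.StudentT('obs_r', mu=mu[keep], sd=sd, nu=2,
                        observed=data_r_shared[ridx][keep])
    obs_i = pm.StudentT('obs_i', mu=mu[keep], sd=sd, nu=2,
                        observed=data_i_shared[ridx][keep])

One more thing worth checking: with manual slicing like this, the minibatch likelihood probably still needs to be rescaled to the full data size (e.g. via the total_size argument of the observed variables), otherwise the prior gets too much weight relative to the data.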

Thanks, will try this. However, your X/y minibatch example above confirms that the minibatches should be synchronized, so maybe there is something wrong on my side.
