Help modeling a difficult logistic regression with minibatches

Dear community,
I’m trying to model a problem that might be difficult to sample. I have a matrix of 0s and 1s, and I have to model the relationship of each pixel with all the others, following this paper.

This is the input matrix (1000 x 1000):

This is the equation that models the pixels. I have to find gamma.
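Reading the model off the code in this post (an inference from the code, not a transcription of the paper's equation), each pixel i with value m_i in {0, 1} appears to be modelled as follows, where d_ij is the Euclidean distance between pixels i and j and σ is the logistic sigmoid:

```latex
q_i = \frac{\sum_{j \neq i} d_{ij}^{-\gamma} \, m_j}{\sum_{j \neq i} d_{ij}^{-\gamma}},
\qquad
p_i = \sigma\!\left( \frac{q_i}{\max_k q_k} \right),
\qquad
m_i \sim \operatorname{Bernoulli}(p_i)
```

The j = i term is excluded because d_ii = 0, which is what the switch statements in the code enforce, and the maximum runs over the pixels in the current batch. Multiplying every distance by the same constant c multiplies both the numerator and the denominator of q_i by c^(-γ), so dividing d by d.sum() leaves q_i unchanged; it only matters numerically.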


Since the distance matrix is too big for a PC (a 1000^2 x 1000^2 distance matrix), I thought it would be easier to use minibatches and compute only part of the matrix at each step.
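For scale: a dense float32 distance matrix over all 1000 x 1000 pixels has 10^12 entries, about 4 TB. A minimal NumPy sketch of that arithmetic, plus a helper that builds only the rows for a batch of pixels (the "compute just part of the matrix" idea). `dense_distance_bytes` and `distance_rows` are hypothetical helper names, not part of PyMC.

```python
import numpy as np

def dense_distance_bytes(n_rows, n_cols, bytes_per_entry=4):
    """Memory needed for a dense pixel-to-pixel distance matrix (float32 by default)."""
    n_pixels = n_rows * n_cols
    return n_pixels * n_pixels * bytes_per_entry

def distance_rows(coords, idx):
    """Euclidean distances from the pixels in `idx` to every pixel.

    coords: (n_pixels, 2) array of (row, col) coordinates.
    The result has shape (len(idx), n_pixels), so memory grows with the
    batch size instead of with n_pixels ** 2.
    """
    diff = coords[idx][:, None, :] - coords[None, :, :]
    return np.sqrt((diff ** 2).sum(axis=-1))

print(dense_distance_bytes(1000, 1000) / 1e12)  # 4.0 (terabytes)

# Small 4 x 5 grid: distances from pixels 0 and 7 to all 20 pixels
rows, cols = np.indices((4, 5))
coords = np.column_stack([rows.ravel(), cols.ravel()]).astype(np.float64)
block = distance_rows(coords, np.array([0, 7]))
print(block.shape)  # (2, 20)
```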

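```python
import numpy as np
import pymc3 as pm
import theano.tensor as tt

# 0/1 input matrix; the context manager closes the file
with np.load('ema/rybski_L_1024.npz') as npz:
    M0 = npz['arr_0'].astype('float32')

# (row, col) coordinates of every pixel, in the same order as M0.reshape(-1, 1).
# B was not defined in the post; this definition is an assumption.
rows, cols = np.indices(M0.shape)
B = np.column_stack([rows.ravel(), cols.ravel()]).astype('float32')

m = tt.as_tensor_variable(M0.reshape(-1, 1))
# data and BB must select the same pixel at every step
data = pm.Minibatch(M0.reshape(-1, 1), 1)
BB = pm.Minibatch(B, 1)

with pm.Model() as pooled_model:

    gamma = pm.Bound(pm.HalfNormal, lower=0.001, upper=6)('gamma', sd=2)
    # Euclidean distance from each batch pixel (BB) to every pixel (B)
    d = tt.sqrt((BB ** 2).sum(1).reshape((BB.shape[0], 1))
                + (B ** 2).sum(1).reshape((1, B.shape[0]))
                - 2 * BB.dot(B.T))
    # Normalization to avoid overflows
    d = d / d.sum()

    dgamma = tt.power(d, -gamma)
    # d == 0 for the pixel itself gives inf, and rounding can push the squared
    # distance slightly below 0 (sqrt gives nan); both entries are set to 0
    dgamma = tt.switch(tt.isnan(dgamma), 0., dgamma)
    dgamma = tt.switch(tt.isinf(dgamma), 0., dgamma)

    # Distance-weighted average of the pixel values, one per batch pixel
    qi = dgamma.dot(m) / dgamma.sum(axis=1, keepdims=True)

    # expected parameter (with a batch of one pixel, qi / qi.max() equals 1 whenever qi > 0)
    p = pm.math.sigmoid(qi / qi.max())

    qi_obs = pm.Bernoulli('qi_obs', p=p, observed=data, total_size=M0.size)

    # note: pm.Minibatch is intended mainly for variational inference (pm.fit)
    pooled_trace = pm.sample(3000, tune=6000, cores=1, progressbar=True, init="advi")
```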

Unfortunately, with the newest PyMC this fails with “NaN occurred in optimization.”, while with the older version it does not converge well. Am I modeling the problem wrong? Can you help me, please?

Thank you
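One likely source of the NaN (an assumption, not confirmed in this thread): `tt.switch` only selects values in the forward pass. In the backward pass the chain rule still multiplies the zero gradient coming through the switch by the derivative of d^(-γ) with respect to γ, which is -ln(d) · d^(-γ) and is infinite at d = 0; 0 · inf = nan. The NumPy sketch below mimics that chain-rule product by hand (`grad_wrt_gamma` is a hypothetical helper) and shows the usual workaround: replace the zero distances with a harmless value *before* taking the power.

```python
import numpy as np

def grad_wrt_gamma(d, gamma, safe):
    """d/dgamma of sum_j mask_j * d_j**(-gamma), chained by hand like an autodiff tool."""
    mask = d > 0                                   # the j == i entry (d == 0) is switched off
    upstream = np.where(mask, 1.0, 0.0)            # gradient passed back through the switch
    base = np.where(mask, d, 1.0) if safe else d   # safe: remove the 0 before the power
    with np.errstate(divide="ignore", invalid="ignore"):
        local = -np.log(base) * base ** (-gamma)   # d/dgamma of base**(-gamma)
        return np.sum(upstream * local)            # 0 * inf -> nan in the unsafe case

d = np.array([0.0, 0.5, 1.0, 2.0])
print(grad_wrt_gamma(d, 1.5, safe=False))  # nan
print(grad_wrt_gamma(d, 1.5, safe=True))   # finite
```

In the Theano model the same idea would mean switching the zero distances to a placeholder such as 1 before `tt.power` and zeroing those weights afterwards, rather than cleaning up inf/nan after the power.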

Minibatch might not work well here, because some batches can contain very little information: the observed matrix is mostly sparse. Moreover, the formulation sums over all pairs of pixels, while a minibatch would only sum over a few pixels (i.e., the pixels in that batch). Wouldn't you get a different model?
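To put a number on the sparsity concern: if a fraction ρ of the pixels is occupied and each batch draws b pixels independently and uniformly with replacement (an assumption about how batches are drawn), the chance that a batch contains no 1 at all is (1 - ρ)^b. With a batch size of 1, most batches are a single 0 whenever ρ is small. The sketch uses a randomly generated matrix as hypothetical data; the helper names are made up for illustration.

```python
import numpy as np

def prob_batch_all_zero(density, batch_size):
    """P(no occupied pixel in a batch), assuming independent uniform draws with replacement."""
    return (1.0 - density) ** batch_size

def empirical_all_zero(M, batch_size, n_batches=10_000, seed=0):
    """Fraction of random batches drawn from the 0/1 matrix M that contain no 1s."""
    rng = np.random.default_rng(seed)
    flat = M.ravel()
    idx = rng.integers(0, flat.size, size=(n_batches, batch_size))
    return np.mean(flat[idx].sum(axis=1) == 0)

# Hypothetical sparse matrix with about 5% occupied pixels
rng = np.random.default_rng(1)
M = (rng.random((200, 200)) < 0.05).astype(np.float32)
density = M.mean()

for b in (1, 10, 100):
    print(b, prob_batch_all_zero(density, b), empirical_all_zero(M, b))
```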

The formulation reminds me a lot of mean-field theory and the Ising model. Would a similar decomposition work here too?
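Along those lines, one property might sidestep the memory problem without minibatches: the weight d_ij^(-γ) depends only on the offset between pixels i and j, much like the translation-invariant couplings of a lattice model. The distance-weighted average q_i = Σ_{j≠i} d_ij^(-γ) m_j / Σ_{j≠i} d_ij^(-γ) that the posted code computes for each pixel is then a ratio of two 2-D convolutions, which FFTs evaluate for all N pixels in roughly O(N log N) time, using a few arrays of about 3000 x 3000 for a 1000 x 1000 image instead of an N x N matrix. This is a swapped-in technique, not something proposed in the thread; the NumPy sketch below works for a fixed gamma only, and the function names are hypothetical. It is checked against the brute-force sum on a small grid.

```python
import numpy as np

def conv2d_full(a, k):
    """Full linear 2-D convolution of a with k via zero-padded real FFTs."""
    shape = (a.shape[0] + k.shape[0] - 1, a.shape[1] + k.shape[1] - 1)
    return np.fft.irfft2(np.fft.rfft2(a, s=shape) * np.fft.rfft2(k, s=shape), s=shape)

def weighted_neighbour_mean(M, gamma):
    """q_i = sum_{j != i} d_ij**-gamma * M_j / sum_{j != i} d_ij**-gamma for every pixel i."""
    H, W = M.shape
    # Kernel indexed by offset (dy, dx) in [-(H-1), H-1] x [-(W-1), W-1]
    dy = np.arange(-(H - 1), H)[:, None]
    dx = np.arange(-(W - 1), W)[None, :]
    dist = np.hypot(dy, dx)
    K = np.zeros_like(dist)
    off_centre = dist > 0
    K[off_centre] = dist[off_centre] ** (-gamma)   # the j == i offset stays 0
    # Crop the full convolution so that output [i] lines up with pixel i
    crop = (slice(H - 1, 2 * H - 1), slice(W - 1, 2 * W - 1))
    num = conv2d_full(M, K)[crop]
    den = conv2d_full(np.ones_like(M), K)[crop]
    return num / den

def brute_force_mean(M, gamma):
    """Same quantity from the explicit N x N distance matrix (small grids only)."""
    rows, cols = np.indices(M.shape)
    coords = np.column_stack([rows.ravel(), cols.ravel()]).astype(float)
    d = np.sqrt(((coords[:, None, :] - coords[None, :, :]) ** 2).sum(axis=-1))
    w = np.zeros_like(d)
    w[d > 0] = d[d > 0] ** (-gamma)
    return (w @ M.ravel() / w.sum(axis=1)).reshape(M.shape)

rng = np.random.default_rng(0)
M_test = (rng.random((6, 7)) < 0.2).astype(float)
print(np.abs(weighted_neighbour_mean(M_test, 1.7) - brute_force_mean(M_test, 1.7)).max())
```

The convolution only replaces how the sums are evaluated; using it inside the PyMC model would still require expressing it in a form the sampler can differentiate with respect to gamma.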

My PyMC formulation sums over all pixels for each pixel in the batch, so the model is still the same. However, yes, I might have problems with sparsity…
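The "model is still the same" argument can be made precise for the likelihood part: with `total_size`, PyMC rescales the log-likelihood of a batch by total_size / batch_size, which gives an unbiased estimate of the full-data log-likelihood provided each pixel's term does not depend on which other pixels are in the batch. (In the posted model p is divided by qi.max() over the batch, so that condition does not hold exactly.) A NumPy check with made-up probabilities: averaging the scaled single-pixel estimates over all pixels recovers the full sum, while their spread shows how noisy a one-pixel batch is compared with a larger one.

```python
import numpy as np

def bernoulli_loglik(y, p):
    """Per-pixel Bernoulli log-likelihood terms."""
    return y * np.log(p) + (1 - y) * np.log1p(-p)

# Made-up data: sparse 0/1 observations and arbitrary per-pixel probabilities
rng = np.random.default_rng(0)
N = 10_000
y = (rng.random(N) < 0.05).astype(float)
p = rng.uniform(0.01, 0.99, size=N)

ll = bernoulli_loglik(y, p)
full = ll.sum()

# Every possible batch of one pixel, scaled by total_size / batch_size = N
single_pixel_estimates = N * ll
print(full, single_pixel_estimates.mean())         # equal up to rounding
print(single_pixel_estimates.std() / abs(full))    # relative noise of one-pixel batches

# Random batches of 100 pixels, scaled by N / 100
idx = rng.integers(0, N, size=(2000, 100))
batch_estimates = ll[idx].sum(axis=1) * N / 100
print(batch_estimates.mean(), batch_estimates.std() / abs(full))
```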

Hmm, I am not familiar with these models; let me check them and understand! Thank you!