Speed up the model: how large does a model need to be to benefit from a GPU?

Dear all,
I currently have a hierarchical model (I can post it if necessary) with data of dimension (20, x, 1000), where x can be up to several thousand. With a subset of the data of dimension (20, 20, 5), sampling took 20 min; with (20, 500, 1000) it ran for more than 1 day, so I stopped it (it could be that my model is structured inefficiently). From my limited knowledge, I am considering a GPU (I would need to apply for resources), minibatching and VI. I have read the dev notes and understand that a GPU is not recommended when the model is small.
So my two questions are: 1) Is my model large enough to benefit from a GPU at all, or should I consider minibatching, VI, or improving the model? 2) Are other PPLs such as TFP theoretically different from PyMC (in terms of memory copying and the like), so that a GPU would be worth trying in my case?
Thanks in advance.

It is very difficult to say without knowing more about your model. What makes it take 20 minutes with 2k datapoints? Does the sampler complain about divergences or non-convergence?

Here is the model. Dimensions are given at the end of each line. Basically it's a Gamma GLM followed by a Poisson. The three dimensions of the observed data are linked through 2 levels of pooling. An identity link function is used for now.
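
The Gamma layer is parametrized by a mean (gammaMu_s) and a standard deviation (gammaStd) through moment matching: shape = mu² / sd² and rate = shape / mu = mu / sd². A minimal NumPy/SciPy check of that mapping, with made-up values for mu and sd (the helper name `gamma_shape_rate` is hypothetical):

```python
import numpy as np
from scipy import stats


def gamma_shape_rate(mu, sd):
    """Moment-matched Gamma parameters, mirroring gammaShape / gammaBeta in the model."""
    mu = np.asarray(mu, dtype=float)
    shape = mu ** 2 / sd ** 2   # alpha
    rate = shape / mu           # beta, equal to mu / sd**2
    return shape, rate


mu = np.array([[0.5, 2.0], [3.0, 10.0]])   # illustrative positive means
sd = 0.8                                    # scalar, like gammaStd with shape=1
shape, rate = gamma_shape_rate(mu, sd)

# SciPy parametrizes the Gamma by shape `a` and `scale`, where scale = 1 / rate
dist = stats.gamma(a=shape, scale=1.0 / rate)
print(dist.mean())  # equals mu up to float rounding
print(dist.std())   # equals sd in every cell up to float rounding
```

Because the rate is shape / mu, this mapping only yields a valid Gamma when the mean is strictly positive.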

```python
import aesara.tensor as at  # PyMC 4.x uses Aesara as its tensor backend
import pymc as pm

# dataDict, nSample, nPos, nComo, nCov1..nCov4 and cov1..cov4 are prepared beforehand
with pm.Model() as model:
    TD = pm.MutableData('TD', dataDict['TD'])  # (nSample, nPos, nComo)
    AD = pm.MutableData('AD', dataDict['AD'])  # (nSample, nPos, nComo)

    coeff_cov1 = pm.HalfNormal('coeff_cov1', sigma=1, shape=(nCov1,))  # (nCov1,)
    coeff_cov2 = pm.HalfNormal('coeff_cov2', sigma=1, shape=(nCov2,))  # (nCov2,)
    coeff_cov3 = pm.HalfNormal('coeff_cov3', sigma=1, shape=(nCov3,))  # (nCov3,)
    coeff_cov4 = pm.HalfNormal('coeff_cov4', sigma=1, shape=(nCov4,))  # (nCov4,)
    intercept = pm.HalfNormal('intercept', sigma=1, shape=1)  # (1,)

    # Below builds the nComo combinations of the covariates above, some categorical, others discrete.
    # Each observed datum corresponds to one combo.
    # Each element of dataDict['comb_indices'] is a list holding one index per covariate.
    # I chose this implementation because the 4 covariates have different dimensions.
    gammaMu_c = pm.Deterministic("gammaMu_c", at.stack([
        at.math.sum(
            intercept
            + coeff_cov1[indices[0]] * cov1
            + coeff_cov2[indices[1]] * cov2
            + coeff_cov3[indices[2]] * cov3
            + coeff_cov4[indices[3]] * cov4)
        for indices in dataDict['comb_indices']
    ]))  # (nComo,)

    # Pooling level 1: each combo-level mean populates the position-level means
    gammaMu_p = pm.Normal("gammaMu_p", mu=gammaMu_c, sigma=1, shape=(nPos, nComo))  # (nPos, nComo)
    # Pooling level 2: each position-level mean populates the sample-level means
    gammaMu_s = pm.Normal("gammaMu_s", mu=gammaMu_p, sigma=1, shape=(nSample, nPos, nComo))  # (nSample, nPos, nComo)
    gammaStd = pm.HalfNormal("gammaStd", sigma=5, shape=1)  # (1,)
    # Moment matching: a Gamma with mean gammaMu_s and standard deviation gammaStd
    gammaShape = pm.Deterministic('gammaShape', gammaMu_s ** 2 / gammaStd ** 2)  # (nSample, nPos, nComo)
    gammaBeta = pm.Deterministic('gammaBeta', gammaShape / gammaMu_s)  # (nSample, nPos, nComo)

    AF = pm.Gamma('AF', alpha=gammaShape, beta=gammaBeta)  # (nSample, nPos, nComo)

    lambda_p = pm.Deterministic('lambda_p', TD * AF)  # (nSample, nPos, nComo)
    # Mask out missing data with a binary mask (1 = observed, 0 = missing)
    obs = pm.Potential('obs', pm.logp(pm.Poisson.dist(mu=lambda_p), AD) * dataDict['AD_mask'])
```
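
What the mask in the Potential does can be checked outside the model. A minimal NumPy/SciPy sketch, using random stand-ins for lambda_p, AD and AD_mask (all values made up): multiplying the elementwise log-likelihood by a 0/1 mask gives the same total as summing over the observed cells only. The sketch also shows one caveat, assuming missing cells might be stored as NaN: the mask cannot remove them, because `nan * 0` is `nan`, so the placeholders in AD must be valid counts.

```python
import numpy as np
from scipy import stats

rng = np.random.default_rng(0)
lam = rng.uniform(1.0, 50.0, size=(3, 4, 5))   # stand-in for lambda_p
counts = rng.poisson(lam)                       # stand-in for AD
mask = rng.random(size=lam.shape) < 0.8         # stand-in for AD_mask (True = observed)
mask[0, 0, 0] = False                           # make sure at least one cell is missing

elementwise = stats.poisson.logpmf(counts, lam)
masked_total = np.sum(elementwise * mask)       # what the Potential contributes
observed_only = np.sum(elementwise[mask])       # sum over observed cells only
print(np.isclose(masked_total, observed_only))

# If missing cells are stored as NaN, masking does not help: nan * 0 is nan
counts_nan = np.where(mask, counts, np.nan)
print(np.sum(stats.poisson.logpmf(counts_nan, lam) * mask))

# Filling missing cells with any valid count (here 0) before masking fixes it
counts_filled = np.where(mask, counts, 0)
filled_total = np.sum(stats.poisson.logpmf(counts_filled, lam) * mask)
print(np.isclose(filled_total, masked_total))
```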

I use the following code to generate toy data (ignoring missing data):
```python
import numpy as np
from scipy.stats import halfnorm

rng = np.random.default_rng(123)
nSample, nPos, nCombo = 20, 20, 10
gammaMu_c_true = halfnorm.rvs(loc=2, scale=.5, size=nCombo, random_state=rng)
gammaMu_p_true = rng.normal(loc=gammaMu_c_true, scale=.05, size=(nPos, nCombo))
gammaMu_s_true = rng.normal(loc=gammaMu_p_true, scale=.1, size=(nSample, nPos, nCombo))

# Identity link: the sample-level means are used directly as the Gamma means
gammaStd_true = halfnorm.rvs(loc=0, scale=2, size=(nSample, nPos, nCombo), random_state=rng)
gammaShape_true = gammaMu_s_true ** 2 / gammaStd_true ** 2
gammaBeta_true = gammaShape_true / gammaMu_s_true
AF_true = rng.gamma(shape=gammaShape_true, scale=1 / gammaBeta_true)
TD_true = rng.integers(2000, 4000, size=(nSample, nPos, nCombo))
lambda_p_true = TD_true * AF_true
AD_true = rng.poisson(lambda_p_true)
```

With 2 chains and 1000 draws, there were around 20 divergences.
Using nSample, nPos, nCombo = 20, 20, 10, it roughly converged OK, judging by the posterior plots.

By the way, you don't need to wrap everything in Deterministics, only the values you need afterwards. It will save memory and computation.
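
To put the Deterministic advice in numbers: every variable kept in the trace stores chains × draws copies of its full shape. A back-of-the-envelope sketch, assuming float64 values (8 bytes each) and the 2 chains × 1000 draws mentioned above; `trace_gigabytes` is a hypothetical helper, and real trace backends add some overhead on top.

```python
from math import prod


def trace_gigabytes(shape, chains=2, draws=1000, bytes_per_value=8):
    """Approximate trace storage for one variable, in GB (1e9 bytes)."""
    return prod(shape) * chains * draws * bytes_per_value / 1e9


small = trace_gigabytes((20, 20, 5))       # the subset that took 20 min
large = trace_gigabytes((20, 500, 1000))   # the run that was stopped after a day
print(f"(20, 20, 5):     {small:.3f} GB per variable")
print(f"(20, 500, 1000): {large:.0f} GB per variable")
```

At the full size, each of gammaShape, gammaBeta and lambda_p would need about 160 GB on its own, on top of the free variables gammaMu_s and AF, which are stored regardless.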

You might benefit from changing how you compute gammaMu_c, but as you noted in your comment, that might not be trivial.
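
One common way to do it is to replace the Python loop with integer-array indexing. A minimal NumPy sketch, assuming (hypothetically) that each combo has one integer index and one numeric covariate value per covariate, stored in arrays `comb_indices` and `cov_values` of shape (nComo, 4); in the posted model `cov1`..`cov4` are not indexed per combo, so the real layout may differ. Integer-array indexing (`coeff[idx_array]`) also works on Aesara tensors, so the same expression could replace the loop plus `at.stack`, giving one gather per covariate instead of nComo separate sub-graphs.

```python
import numpy as np

rng = np.random.default_rng(1)
n_combo = 6
n_levels = (3, 4, 2, 5)  # hypothetical sizes of coeff_cov1..coeff_cov4

intercept = 0.7
coeffs = [np.abs(rng.normal(size=n)) for n in n_levels]
# Hypothetical inputs: one index per covariate per combo, plus the matching covariate value
comb_indices = np.stack([rng.integers(0, n, size=n_combo) for n in n_levels], axis=1)  # (n_combo, 4)
cov_values = rng.uniform(0.0, 2.0, size=(n_combo, 4))                                   # (n_combo, 4)

# Loop version, mirroring the list comprehension + stack in the model
loop = np.array([
    intercept + sum(coeffs[k][idx[k]] * cov_values[c, k] for k in range(4))
    for c, idx in enumerate(comb_indices)
])

# Vectorized version: one gather per coefficient vector, then a single elementwise sum
vectorized = intercept + sum(coeffs[k][comb_indices[:, k]] * cov_values[:, k] for k in range(4))
print(np.allclose(loop, vectorized))
```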

Otherwise it looks fine, but the divergences could mean bad priors or a challenging model. I would focus on fixing that before worrying about speedup; otherwise you would still be getting invalid samples, just faster.
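
One concrete thing to check on the priors: gammaBeta = gammaShape / gammaMu_s is only a valid Gamma rate when gammaMu_s > 0, but gammaMu_s comes from two stacked Normal layers with sigma=1, which can go negative. Given gammaMu_c = m, gammaMu_s is marginally Normal(m, sqrt(1² + 1²)), so its prior probability of being negative is Φ(−m/√2). A small SciPy sketch of that calculation (the helper name is made up):

```python
import numpy as np
from scipy import stats


def prob_negative_mean(m, sigma_p=1.0, sigma_s=1.0):
    """P(gammaMu_s < 0) given gammaMu_c = m, after both Normal pooling levels."""
    total_sd = np.sqrt(sigma_p ** 2 + sigma_s ** 2)
    return stats.norm.cdf(0.0, loc=m, scale=total_sd)


for m in (0.5, 1.0, 2.0, 4.0):
    print(f"gammaMu_c = {m}: P(gammaMu_s < 0) = {prob_negative_mean(m):.3f}")

# The toy data was simulated with much tighter pooling scales
print(prob_negative_mean(2.0, sigma_p=0.05, sigma_s=0.1))
```

With m around 2, as in the toy data, that is roughly 8% of the prior mass in every one of the nSample × nPos × nComo cells. This is one plausible contributor to the divergences, not a diagnosis.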