PyMC, NumPyro GPs and ‘Transforming RVs’ memory behaviour

While developing GP models with numpyro to sample on the GPU, I’ve noticed some strange memory behavior. After pymc.sampling.jax.sample_numpyro_nuts finishes sampling (very fast), it enters a ‘Transforming RVs’ phase, which triggers a massive memory spike. I’m not familiar enough with the internals of pymc to know which transformation is being done and how (I think it might be the one controlled by the reparameterize=True parameter of pymc.gp.Latent().prior), but this step seems to have massive memory requirements. I’ve empirically determined that the amount of memory is roughly:

$$64\,(\text{b}) \times n_{\text{chains}} \times n_{\text{draws}} \times d\,(\text{number of GPs}) \times N^2 \times 10^{-9}\,(\text{GB}\,\text{b}^{-1})$$

where N is the number of data points.

Is there a way to avoid this giant tensor? Perhaps do the transform in minibatches? If we want to build a model with many target variables, the memory requirements grow unreasonably, e.g.:

$$64\,(\text{b}) \times 2\,(\text{chains}) \times 1000\,(\text{draws}) \times 23\,(\text{number of GPs}) \times 415^2 \times 10^{-9}\,(\text{GB}\,\text{b}^{-1}) \approx 507\ \text{GB!!}$$
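The estimate above can be turned into a pre-flight check that fails before sampling instead of with an XLA OOM error afterwards. This is a minimal sketch: estimate_transform_memory_gb and check_memory_budget are hypothetical helpers (not part of PyMC), the factor 64 is simply the empirical constant from the formula, and the available memory has to be supplied by hand (e.g. read off nvidia-smi). With N ≈ 415 (the dataset size mentioned later in the thread) the formula gives ~507 GB; with N = 500 it gives ~736 GB.

```python
def estimate_transform_memory_gb(n_chains, n_draws, n_gps, n_points, factor=64):
    """Empirical peak memory (GB) of the post-sampling 'Transforming RVs' step.

    `factor` is the empirically observed constant from the formula (64);
    it is a fitted number, not derived from PyMC internals.
    """
    return factor * n_chains * n_draws * n_gps * n_points**2 * 1e-9


def check_memory_budget(n_chains, n_draws, n_gps, n_points, available_gb):
    """Hypothetical pre-flight check: raise early instead of OOM-ing after sampling."""
    needed = estimate_transform_memory_gb(n_chains, n_draws, n_gps, n_points)
    if needed > available_gb:
        raise MemoryError(
            f"Transforming RVs needs ~{needed:.0f} GB, but only {available_gb} GB is available"
        )
    return needed


# 2 chains, 1000 draws, 23 GPs, ~415 points
print(round(estimate_transform_memory_gb(2, 1000, 23, 415)))  # 507
```

On an 11 GB card, check_memory_budget(2, 1000, 23, 415, available_gb=11) would raise before any sampling time is spent.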

Even if it’s impossible to avoid this, it would be useful (and pretty easy, I think) to add a check for sufficient memory. It certainly was pretty unclear to me where all the XLA OOM errors were coming from. Here’s the model I’m working with for reference:

with pymc.Model() as binary_model:
    gps=[]
    gppriors=[]
    _,M=chem_indicators.shape
    η=pymc.Normal('η', mu=3.0, sigma=1.0)
    for target in binary_targets_tidy.columns:
        mean_func=pymc.gp.mean.Constant(c=0)
        σv = pymc.LogNormal(f'σv_{target}', mu=.0, sigma=.25)+.5
        λ = pymc.MvNormal(f'λ_{target}', mu=3.0*np.ones(M),cov=σv*np.eye(M),)
        σ = pymc.Normal(f'σ_{target}', mu=3.0,sigma=1.0)
        k = η**2*pymc.gp.cov.ExpQuad(M, ls=λ**2)
        noise = pymc.gp.cov.WhiteNoise(σ)
        κ = k+noise
        gp=pymc.gp.Latent(mean_func=mean_func, 
                          cov_func=κ)
        _f=gp.prior(f'_f_{target}', X_train.values)
        gps.append(gp)
        gppriors.append(_f)
    f = pymc.Deterministic('f', at.math.sigmoid(at.stack(*gppriors).T))
    for idx,target in enumerate(binary_targets_tidy.columns):
        y_obs = pymc.Bernoulli(f'y_obs_{target}',p=f[:,[idx]], observed=Y_train.loc[:,target] )

Could you maybe provide a bit more detail on how fast it samples? It looks like a pretty large model, so I’d be a bit surprised if it samples very quickly. The reparameterization you mention shouldn’t be responsible for the issue you’re seeing, because it is really about whether you use a centered or non-centered parameterization. It shouldn’t trigger anything after sampling is done.
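To make the centered/non-centered distinction concrete, here is a NumPy sketch (not PyMC’s code) of a non-centered GP prior: sample white noise v ~ N(0, I) and set f = L v, where L is the Cholesky factor of the covariance K, so that f ~ N(0, K). As far as I know, reparameterize=True in pymc.gp.Latent.prior builds f this way and keeps it as a Deterministic. Recovering f from a draw of v needs the N×N factor L, and in the model above the kernel hyperparameters are random, so L differs from draw to draw. The kernel, lengthscale, jitter and data below are illustrative assumptions.

```python
import numpy as np

def exp_quad(X, lengthscale):
    """Squared-exponential kernel on the rows of X (shape (N, M))."""
    sq = np.sum((X[:, None, :] - X[None, :, :]) ** 2, axis=-1)
    return np.exp(-0.5 * sq / lengthscale**2)

def noncentered_gp_draws(X, lengthscale, n_draws, rng, jitter=1e-6):
    """Draw f = L @ v with v ~ N(0, I), which gives f ~ N(0, K)."""
    N = X.shape[0]
    K = exp_quad(X, lengthscale) + jitter * np.eye(N)  # jitter keeps K positive definite
    L = np.linalg.cholesky(K)                          # N x N factor
    v = rng.standard_normal((n_draws, N))              # the white-noise ("rotated") variables
    f = v @ L.T                                        # row i equals L @ v[i]
    return f, K

rng = np.random.default_rng(0)
X = rng.uniform(size=(30, 2))
f, K = noncentered_gp_draws(X, lengthscale=0.5, n_draws=20000, rng=rng)
# The empirical covariance of the draws should be close to K.
print(np.max(np.abs(np.cov(f, rowvar=False) - K)))
```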

One speedup I see is using a Normal vector instead of an MvNormal with a diagonal covariance matrix for the lengthscale prior. If you need a non-diagonal covariance I’d try the non-centered version there too. Here’s more info.
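The lengthscale prior in the model uses cov=σv*np.eye(M), and a multivariate normal with diagonal covariance has exactly the same density as M independent normals. One detail to watch: σv is a variance here, so each independent normal has standard deviation sqrt(σv), not σv. The NumPy check below compares the two log-densities; σv = 1.3 and the random x are arbitrary illustrative values, and M = 64 matches the number of input columns mentioned later in the thread. In PyMC the vector version would presumably be something like pymc.Normal(f'λ_{target}', mu=3.0, sigma=σv**0.5, shape=M), which I have not tested against this model.

```python
import numpy as np

def mvnormal_logpdf(x, mu, cov):
    """Log-density of a multivariate normal, computed via a Cholesky factor."""
    L = np.linalg.cholesky(cov)
    z = np.linalg.solve(L, x - mu)
    k = x.shape[0]
    return -0.5 * (z @ z) - np.sum(np.log(np.diag(L))) - 0.5 * k * np.log(2 * np.pi)

def normal_logpdf(x, mu, sigma):
    """Elementwise univariate normal log-density."""
    return -0.5 * ((x - mu) / sigma) ** 2 - np.log(sigma) - 0.5 * np.log(2 * np.pi)

M = 64
sigma_v = 1.3                 # plays the role of σv (a variance) in the model
mu = 3.0 * np.ones(M)
x = np.random.default_rng(1).normal(size=M)

lp_mv = mvnormal_logpdf(x, mu, sigma_v * np.eye(M))
lp_indep = normal_logpdf(x, mu, np.sqrt(sigma_v)).sum()  # note the sqrt
print(lp_mv, lp_indep)
```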

Is your error coming from here? It looks like this happens when JAX moves everything from the GPU to the CPU.

Maybe this is relevant? Out of memory when "transforming variables" in Numpyro & JAX - #11 by wnorcbrown

Sampling time behaves somewhat weirdly. If I downsample from the original ~415 points to about 30%, it samples in about 20 min/chain. Sampling time doesn’t seem to increase monotonically with the number of points; rather, past some unclear threshold, it jumps abruptly to about 2 hours per chain.

If the reparameterization isn’t responsible, that would be very strange, as I’ve repeatedly noticed:

  1. GPU memory consumption is very low during sampling, and as soon as ‘Transforming RVs’ appears it jumps to near the maximum.
  2. Every time reparameterize=True, I get hit with an OOM error; every time reparameterize=False, no OOM errors appear. Specifically, with postprocessing_backend='cpu' there are no error messages; the kernel just dies. Setting postprocessing_backend='gpu' raises an XLA OOM error (nvidia-smi seems to agree: all 11 GB of VRAM are used up at this point).
  3. The reported error shows roughly ... f64(2,N,N,1) (it might also have a sample dimension; I’ll have to check). I note that the input data has 64 columns.
  4. I’ve no idea whether vmap has anything to do with the OOM error. The CPU transfer explanation doesn’t seem plausible to me, since that transfer should happen regardless of whether reparameterize is True or not.
  5. If by normal vector you mean something like pymc.Normal(..., shape=M), I’ve considered this. However, I’d like each weight to have its own mean (there’s no reason in my case they should be the same), and my understanding is that pymc.Normal(..., shape=M) gives the same expectation for all the weights.
  6. This post does not seem applicable here. Splitting it into two deterministics for stack and sigmoid seems to have no effect on computation time or memory, as far as I can tell.
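In XLA error messages, f64 denotes the float64 dtype (8 bytes per element), so the size of the tensor in point 3 follows directly from its shape. A small sketch; tensor_gb is a hypothetical helper, N = 415 is the dataset size mentioned above, and the 1000-draw dimension is only the “might also have a sample dimension” case, not something confirmed.

```python
import numpy as np

def tensor_gb(shape, dtype=np.float64):
    """Size in GB (1e9 bytes) of a dense array with the given shape and dtype."""
    return int(np.prod(shape, dtype=np.int64)) * np.dtype(dtype).itemsize / 1e9

N = 415
print(tensor_gb((2, N, N, 1)))             # as reported: chains x N x N x 1
print(tensor_gb((2, 1000, N, N, 1)))       # if there is also a 1000-draw dimension
print(23 * tensor_gb((2, 1000, N, N, 1)))  # one such tensor for each of 23 GPs
```

With the draw dimension, a single GP’s tensor is about 2.8 GB and 23 of them come to about 63 GB, well beyond an 11 GB card.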

No, it doesn’t. Normal("name", shape=(10,)) will create a vector of 10 independent normal variates. You can even give them different prior parameters: Normal("name", mu=np.arange(10), shape=(10,))
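The same broadcasting idea can be seen in plain NumPy (this is an analogue, not PyMC): a single vectorised call with a vector of means produces independent normals, each centred on its own mean. The seed and number of draws are arbitrary.

```python
import numpy as np

rng = np.random.default_rng(2)
mu = np.arange(10.0)  # a different mean for each of the 10 elements
# One call, 10 independent normals per row; column i is centred at mu[i].
draws = rng.normal(loc=mu, scale=1.0, size=(50000, 10))
print(draws.mean(axis=0).round(1))
```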

This results in a much more efficient computational graph because we only need to represent one node instead of 10.

This is explained here: Distribution Dimensionality — PyMC 5.0.0 documentation

Good to know, I’ll look into implementing it. The number of variables doesn’t seem related to the memory error, though: swapping them out for a single scalar (or a single pymc.Normal RV) doesn’t seem to change anything about the memory errors.

It’s a bit hard to figure out what might exactly be happening since your model isn’t simple and we can’t run it on our end and see what you’re seeing. Is it possible to boil it down to something that we could run?

I can say something regarding point 2 and reparameterize, though. Reparameterize should basically always be True, so it’s very strange that that’s what gives you the errors. As far as I know, at least with NUTS, the only reason to set reparameterize to False is to show that it doesn’t work! Figuring out what’s up with this is where I’d start.

You have many variables created like this in nested loops. Did you try vectorizing with batched shapes everywhere you could or just in one place?

@ricardoV94 I’m not sure how vectorizing would be done in this instance. The idea is that I want multiple GPs, each with its own kernel, and I can’t think of any way to get this behavior other than a loop.

@bwengals I will try, although the multiple GPs seem key to getting the error, so aside from fixing all the kernel parameters I can’t think of another way to trim this down.

All the variables that are not GPs can probably be created outside of the loop and indexed where needed, and the same goes for the likelihood (maybe?).
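A NumPy sketch of what batched shapes could look like here: stack the per-GP hyperparameters along a leading axis of size d, so all d covariance matrices (η²·ExpQuad plus σ² white noise, as in the model) come out of one expression, and evaluate the Bernoulli likelihood for all targets at once. This only illustrates the shapes. batched_cov and bernoulli_logit_loglik are hypothetical helpers, ls plays the role of λ**2 in the model, and whether pymc.gp.Latent accepts batched covariances like this is an assumption I have not checked.

```python
import numpy as np

def batched_cov(X, eta, ls, sigma):
    """eta**2 * ExpQuad + sigma**2 * I for d GPs at once.

    X:     (N, M) inputs shared by all GPs
    eta:   scalar amplitude shared by all GPs (like η in the model)
    ls:    (d, M) per-GP, per-dimension lengthscales
    sigma: (d,) per-GP white-noise scale
    returns an array of shape (d, N, N)
    """
    N = X.shape[0]
    Xs = X[None, :, :] / ls[:, None, :]            # (d, N, M) scaled inputs
    sq = np.sum(Xs**2, axis=-1)                    # (d, N) squared norms
    r2 = sq[:, :, None] + sq[:, None, :] - 2.0 * Xs @ np.swapaxes(Xs, 1, 2)
    k = eta**2 * np.exp(-0.5 * np.maximum(r2, 0.0))  # clip tiny negative round-off
    return k + sigma[:, None, None] ** 2 * np.eye(N)

def bernoulli_logit_loglik(f, y):
    """Summed Bernoulli log-likelihood for all targets, with p = sigmoid(f).

    f, y: (N, d). Uses y*log(p) + (1-y)*log(1-p) = y*f - log(1 + exp(f)).
    """
    return np.sum(y * f - np.logaddexp(0.0, f))

rng = np.random.default_rng(0)
X = rng.normal(size=(50, 64))  # illustrative: 50 points, 64 input columns
K = batched_cov(X, 1.0, np.ones((23, 64)), 0.5 * np.ones(23))
print(K.shape)  # (23, 50, 50)
```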

Also, I hadn’t noticed this before, but you could consider not wrapping the contents of the Deterministic f (just write f = at.math.sigmoid(...)). This way it won’t be computed and stored in the trace after sampling is over. If you need these values, you can always try to recreate them later. This might go a long way in reducing the memory spike you’re observing.