Out of memory when "Transforming variables" in NumPyro & JAX

I am able to finish sampling a large model using the JAX backend on a GPU, but it fails during the "Transforming variables" step because it requires too much RAM. I set postprocessing_backend="cpu", which has far more RAM available than my GPU, but it still runs out.

Looking at the code in pymc.sampling_jax.sample_numpyro_nuts, it seems to be happening here:

def sample_numpyro_nuts():
    # ...
    print("Transforming variables...", file=sys.stdout)
    # Build a single JAX function that maps the raw sampled values to the
    # variables requested in the trace (deterministics, transformed variables)
    jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=vars_to_sample)
    # Apply it across the (chain, sample) axes in one shot with nested vmap
    result = jax.vmap(jax.vmap(jax_fn))(
        *jax.device_put(raw_mcmc_samples, jax.devices(postprocessing_backend)[0])
    )
    mcmc_samples = {v.name: r for v, r in zip(vars_to_sample, result)}

    tic4 = datetime.now()
    print("Transformation time = ", tic4 - tic3, file=sys.stdout)
    # ...

Is there a step that is the likely culprit? (I’ve modified the code on my machine to print out debug statements between each of these lines, so I should be able to figure it out in a couple of hours.) Is there something I can do with the current version of PyMC to address this?

Thank you!

Update: from my testing, it is happening during the vmap step:

result = jax.vmap(jax.vmap(jax_fn))(
    *jax.device_put(raw_mcmc_samples, jax.devices(postprocessing_backend)[0])
)

where postprocessing_backend = "cpu". Is there some way to make this process require less RAM?

CC @ferrine

The only way I know of so far to use less RAM is to reduce the number of Deterministic variables
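
A toy sketch of the pattern (the model here is hypothetical; only the Deterministic part matters): instead of wrapping derived quantities in pm.Deterministic, compute them from the trace after sampling, so they never enter the "Transforming variables" graph:

import numpy as np
import pymc as pm
from pymc.sampling_jax import sample_numpyro_nuts

data = np.random.normal(size=100)  # placeholder data for the toy model

with pm.Model() as model:
    mu = pm.Normal("mu", 0.0, 1.0)
    sigma = pm.HalfNormal("sigma", 1.0)
    pm.Normal("y", mu, sigma, observed=data)
    # Instead of storing a derived quantity for every draw:
    #     pm.Deterministic("snr", mu / sigma)
    idata = sample_numpyro_nuts(postprocessing_backend="cpu")

# ...compute it from the posterior afterwards, outside the model:
snr = idata.posterior["mu"] / idata.posterior["sigma"]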

I’ll definitely try that! Thank you!

Also, if you are not planning to do model comparison, disabling the log-likelihood computation via idata_kwargs can help when your model has many observations, e.g.:
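
A rough sketch of what that looks like (assuming a model context like the sketch above, and a PyMC version whose sample_numpyro_nuts accepts idata_kwargs):

with model:
    idata = sample_numpyro_nuts(
        postprocessing_backend="cpu",
        # skip the pointwise log-likelihood, which otherwise allocates an
        # array of shape (chains, draws, n_observations) in postprocessing
        idata_kwargs={"log_likelihood": False},
    )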

Thank you for the recommendation. I don’t think it would fix this problem, though, because those parameters aren’t used at the step that is failing. (Of course, you know the code base far better than me, so please correct me if I’m wrong.) But I have seen this suggested for other RAM problems, so I’ll definitely keep it in mind. So far, reducing the number of Deterministic variables has helped a lot.


@ferrine That made very little difference for me.

I’m limited to running very small chains.

An option to use jax.lax.map (see the JAX documentation) should save memory and stop the memory spike; a rough sketch of the idea is below.
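
Untested sketch, reusing jax_fn, raw_mcmc_samples, and postprocessing_backend from the snippet above:

import jax

def transform_chain(chain_samples):
    # chain_samples is a tuple of arrays, each with a leading samples axis;
    # lax.map applies jax_fn one sample at a time, bounding peak memory
    return jax.lax.map(lambda args: jax_fn(*args), chain_samples)

samples = jax.device_put(raw_mcmc_samples, jax.devices(postprocessing_backend)[0])
# vmap still handles the (small) chains axis; lax.map replaces the inner vmap
result = jax.vmap(transform_chain)(tuple(samples))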

In fact, this xmap-based version seems to work:

from jax.experimental.maps import SerialLoop, xmap

num_chunks = 10  # the number of samples must be divisible by num_chunks
# Inner xmap: evaluate jax_fn over the named axes, looping serially over
# "samples" in num_chunks pieces instead of materializing them all at once
mapper = xmap(
    jax_fn,
    in_axes=["chain", "samples", ...],
    out_axes=["chain", "samples", ...],
    axis_resources={"samples": SerialLoop(num_chunks)},
)
loop = xmap(mapper, in_axes=[...], out_axes=[...])
result = loop(*jax.device_put(raw_mcmc_samples, jax.devices(postprocessing_backend)[0]))

This massively decreases memory usage and returns the same result as vmap; the SerialLoop simply evaluates the samples axis in num_chunks sequential pieces instead of all at once.


This is really nice!! How does the speed compare to vmap?

Would you like to send a PR to PyMC?


As far as I could tell, the speed was the same.

Sure :+1:
