Mcbackend memory fail

In trying to use mcbackend to store samples from a big model, I seem to have run into exactly the issue I’m trying to solve, namely memory problems with the full trace. @michaelosthege, is this something you have encountered before? Key parts of the model and the error message are below. Thanks much! (ping @AlexAndorra)

ch_client = clickhouse_driver.Client("localhost")
ch_backend = mcbackend.ClickHouseBackend(ch_client)

with pm.Model(coords=COORDS) as model_x:
    ...

with model_x:
    pm.sample(1000,trace=ch_backend)

Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Dimensionality of sample stat 'sampler_0__warning' is undefined. Assuming ndim=0.
Dimensionality of sample stat 'sampler_0__warning' is undefined. Assuming ndim=0.
Dimensionality of sample stat 'sampler_0__warning' is undefined. Assuming ndim=0.
Dimensionality of sample stat 'sampler_0__warning' is undefined. Assuming ndim=0.
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [PsppIdent_odds, latent_log_species_landings_, species_sigma_year, TaxonAGG_odds_, taxon_sigma_year, intercept, importer_species_effect, dyad_effect, trade_effect, sd_intercept_sd, lsd_intercept, sd_unreliability_sd, lsd_unreliability_effect]
100.00% [8000/8000 1:42:43<00:00 Sampling 4 chains, 0 divergences]
Traceback (most recent call last):
  File "/Users/aaronmacneil/Dropbox/My Mac (Mac-mini.local)/Documents/GitHub/Global_Shark_Meat/notebooks/Joint_Trade_Landings_Model.py", line 364, in <module>
    pm.sample(1000,trace=ch_backend)
  File "/Users/aaronmacneil/mambaforge/envs/gsm-project/lib/python3.11/site-packages/pymc/sampling/mcmc.py", line 827, in sample
    return _sample_return(
           ^^^^^^^^^^^^^^^
  File "/Users/aaronmacneil/mambaforge/envs/gsm-project/lib/python3.11/site-packages/pymc/sampling/mcmc.py", line 861, in _sample_return
    mtrace = MultiTrace(traces)[:length]
             ~~~~~~~~~~~~~~~~~~^^^^^^^^^
  File "/Users/aaronmacneil/mambaforge/envs/gsm-project/lib/python3.11/site-packages/pymc/backends/base.py", line 370, in __getitem__
    return self._slice(idx)
           ^^^^^^^^^^^^^^^^
  File "/Users/aaronmacneil/mambaforge/envs/gsm-project/lib/python3.11/site-packages/pymc/backends/base.py", line 537, in _slice
    new_traces = [trace._slice(slice) for trace in self._straces.values()]
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/aaronmacneil/mambaforge/envs/gsm-project/lib/python3.11/site-packages/pymc/backends/base.py", line 537, in <listcomp>
    new_traces = [trace._slice(slice) for trace in self._straces.values()]
                  ^^^^^^^^^^^^^^^^^^^
  File "/Users/aaronmacneil/mambaforge/envs/gsm-project/lib/python3.11/site-packages/pymc/backends/mcbackend.py", line 194, in _slice
    draw = self._chain.get_draws_at(i, var_names=vnames)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/aaronmacneil/Dropbox/My Mac (Mac-mini.local)/Documents/GitHub/mcbackend-0.5.2/mcbackend/backends/clickhouse.py", line 291, in get_draws_at
    return self._get_row_at(idx, var_names)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/aaronmacneil/Dropbox/My Mac (Mac-mini.local)/Documents/GitHub/mcbackend-0.5.2/mcbackend/backends/clickhouse.py", line 236, in _get_row_at
    data = self._client.execute(query)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/aaronmacneil/mambaforge/envs/gsm-project/lib/python3.11/site-packages/clickhouse_driver/client.py", line 373, in execute
    rv = self.process_ordinary_query(
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/aaronmacneil/mambaforge/envs/gsm-project/lib/python3.11/site-packages/clickhouse_driver/client.py", line 571, in process_ordinary_query
    return self.receive_result(with_column_types=with_column_types,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/aaronmacneil/mambaforge/envs/gsm-project/lib/python3.11/site-packages/clickhouse_driver/client.py", line 204, in receive_result
    return result.get_result()
           ^^^^^^^^^^^^^^^^^^^
  File "/Users/aaronmacneil/mambaforge/envs/gsm-project/lib/python3.11/site-packages/clickhouse_driver/result.py", line 50, in get_result
    for packet in self.packet_generator:
  File "/Users/aaronmacneil/mambaforge/envs/gsm-project/lib/python3.11/site-packages/clickhouse_driver/client.py", line 220, in packet_generator
    packet = self.receive_packet()
             ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/aaronmacneil/mambaforge/envs/gsm-project/lib/python3.11/site-packages/clickhouse_driver/client.py", line 237, in receive_packet
    raise packet.exception
clickhouse_driver.errors.ServerException: Code: 241.
DB::Exception: Memory limit (total) exceeded: would use 64.98 GiB (attempt to allocate chunk of 34359869503 bytes), maximum: 57.60 GiB. OvercommitTracker decision: Query was selected to stop by OvercommitTracker.: while reading column export_proportion at ./store/9e5/9e596fd0-8db6-499c-ba4d-219fa128a121/: While executing Log. Stack trace:

<Empty trace>

The exception sounds like it’s happening after the MCMC completed, when PyMC automatically converts the trace to InferenceData. I’ll have to check the code to see if you can bypass this step by passing “return_inferencedata=False, compute_convergence_checks=False”.
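A sketch of what that call would look like, assuming those kwargs do skip the conversion (untested):

with model_x:
    # Untested sketch: keep the trace in ClickHouse and skip the in-memory
    # InferenceData conversion and convergence checks after sampling.
    mtrace = pm.sample(
        1000,
        trace=ch_backend,
        return_inferencedata=False,
        compute_convergence_checks=False,
    )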

Otherwise the trace should have successfully ended up in ClickHouse, so you can read it selectively using the McBackend API. Did I add kwargs to “to_inferencedata()” maybe?

Ah ok -

Yes, the trace is in ClickHouse so no issues there. There is a **kwargs in to_inferencedata(), but it’s not clear (to this user) from the code how to access a particular variable or subset.

The to_inferencedata(**kwargs) are forwarded to arviz.from_dict, so they won’t help here.

You can do this:

run = ch_backend.get_run(your_run_id)
chains = run.get_chains()

c0 = chains[0]
some_var = c0.get_draws("some_var")  # load all draws for this variable
some_var = c0.get_draws("some_var", slc=slice(200, None, 50)  # from 200 until end, with 50x thinning

I know this is more tedious than InferenceData, but it is both fast and memory-efficient.
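For example, one could assemble a slimmed-down InferenceData from selectively loaded draws. A rough sketch (the variable names and run id are placeholders, and it assumes all chains have equal length):

import arviz
import numpy

run = ch_backend.get_run(your_run_id)
chains = run.get_chains()

# Load only the needed variables, thinned, and stack them into (chain, draw, ...) arrays.
var_names = ["some_var", "another_var"]
slc = slice(200, None, 50)
posterior = {
    vname: numpy.stack([chain.get_draws(vname, slc) for chain in chains])
    for vname in var_names
}
idata_small = arviz.from_dict(posterior=posterior)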

In principle, one can implement an McBackend backend to store using xarray/HDF5, which would then allow for partial read operations from the local filesystem.
I’ll include this in the GSoC 2024 project description which I’m currently updating here: GSoC 2024 projects · pymc-devs/pymc Wiki · GitHub

Another idea would be to add slc: slice and var_names: Collection[str] kwargs to to_inferencedata() to facilitate loading only certain parts into the InferenceData structure.
This is easier to implement than the xarray/HDF5 stuff. (And I’ll include it in the project description too.)
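For illustration, a call under that proposed (not yet implemented) signature might look like:

# Hypothetical once the kwargs exist: one variable, thinned 10x from draw 500 onwards.
idata = run.to_inferencedata(var_names=["some_var"], slc=slice(500, None, 10))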

Ok!

Thanks very much for the help - those enhancements look ideal, as the xarray goodness is a major PyMC advantage.

I’ve hacked a quick-and-dirty solution by adding a var_names=None kwarg to to_inferencedata() in core.py - seems to work so far!

def to_inferencedata(self, *, var_names=None, equalize_chain_lengths: bool = True, **kwargs) -> InferenceData:
        """Creates an ArviZ ``InferenceData`` object from this run.

        Parameters
        ----------
        var_names : sequence of str, optional
            Names of the variables to include; defaults to all variables in the run.
        equalize_chain_lengths : bool
            Whether to truncate all chains to the shortest chain length (default: ``True``).
        **kwargs
            Will be forwarded to ``arviz.from_dict()``.

        Returns
        -------
        idata : arviz.InferenceData
            Samples and metadata of this inference run.
        """
        if not _HAS_ARVIZ:
            raise ModuleNotFoundError("ArviZ is not installed.")

        var_list = [n.name for n in self.meta.variables]

        if var_names:
            # Select the matching Variable objects; .index() raises ValueError for unknown names.
            variables = [self.meta.variables[var_list.index(v)] for v in var_names]
        else:
            variables = self.meta.variables
        chains = self.get_chains()

        nonrigid_vars = {var for var in variables if var.undefined_ndim or not is_rigid(var.shape)}
        if nonrigid_vars:
            raise NotImplementedError(
                "Creating InferenceData from runs with non-rigid variables is not supported."
                f" The non-rigid variables are: {nonrigid_vars}."
            )

        chain_lengths = {c.cid: len(c) for c in chains}
        if len(set(chain_lengths.values())) != 1:
            msg = f"Chains vary in length. Lenghts are: {chain_lengths}"
            if not equalize_chain_lengths:
                msg += (
                    "\nArviZ does not properly support uneven chain lengths (see ArviZ issue #2094)."
                    "\nWe'll try to give you an InferenceData, but best case the chain & draw dimensions"
                    " will be messed-up as {'chain': 1, 'draws': n_chains}."
                    "\nYou won't be able to save this InferenceData to a file"
                    " and you should expect many ArviZ functions to choke on it."
                    "\nSpecify `to_inferencedata(equalize_chain_lengths=True)` to get regular InferenceData."
                )
            else:
                msg += "\nTruncating to the length of the shortest chain."
            _log.warning(msg)
        min_clen = None
        if equalize_chain_lengths:
            # A minimum chain length is introduced so that all chains have equal length
            min_clen = min(chain_lengths.values())
        # Aggregate draws and stats, while splitting into warmup/posterior
        warmup_posterior = collections.defaultdict(list)
        warmup_sample_stats = collections.defaultdict(list)
        posterior = collections.defaultdict(list)
        sample_stats = collections.defaultdict(list)
        for c, chain in enumerate(chains):
            # Create a slice to use when fetching the variables
            if min_clen is None:
                # Every retrieved array is shortened to the previously determined chain length.
                # Needed for backends which may get inserts in between our get_draws/get_stats calls.
                slc = slice(0, chain_lengths[chain.cid])
            else:
                slc = slice(0, min_clen)

            # Obtain a mask by which draws can be split into warmup/posterior
            try:
                # Use the same slice to avoid shape issues in case the chain is still active
                tune = get_tune_mask(chain, slc)
            except KeyError:
                if c == 0:
                    _log.warning(
                        "No 'tune' stat found. Assuming all iterations are posterior draws."
                    )
                tune = numpy.full((slc.stop,), False)

            # Split all variables draws into warmup/posterior
            for var in variables:
                draws = chain.get_draws(var.name, slc)
                warmup_posterior[var.name].append(draws[tune])
                posterior[var.name].append(draws[~tune])
            # Same for sample stats
            for svar in self.meta.sample_stats:
                stats = chain.get_stats(svar.name, slc)
                warmup_sample_stats[svar.name].append(stats[tune])
                sample_stats[svar.name].append(stats[~tune])

        w_pst = cast(Dict[str, Union[Sequence, numpy.ndarray]], warmup_posterior)
        w_ss = cast(Dict[str, Union[Sequence, numpy.ndarray]], warmup_sample_stats)
        pst = cast(Dict[str, Union[Sequence, numpy.ndarray]], posterior)
        ss = cast(Dict[str, Union[Sequence, numpy.ndarray]], sample_stats)
        if not equalize_chain_lengths:
            # Convert ragged arrays to object-dtyped ndarray because NumPy >=1.24.0 no longer does that automatically
            w_pst = {k: as_array_from_ragged(v) for k, v in warmup_posterior.items()}
            w_ss = {k: as_array_from_ragged(v) for k, v in warmup_sample_stats.items()}
            pst = {k: as_array_from_ragged(v) for k, v in posterior.items()}
            ss = {k: as_array_from_ragged(v) for k, v in sample_stats.items()}

        idata = from_dict(
            warmup_posterior=w_pst,
            warmup_sample_stats=w_ss,
            posterior=pst,
            sample_stats=ss,
            coords=self.coords,
            dims=self.dims,
            attrs=self.meta.attributes,
            constant_data=self.constant_data,
            observed_data=self.observed_data,
            save_warmup=True,
            **kwargs,
        )
        return idata
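With that patch, loading only a subset of variables looks like this (the variable names are placeholders for ones in the actual model):

run = ch_backend.get_run(your_run_id)
idata_small = run.to_inferencedata(var_names=["some_var", "another_var"])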

Do you want to open a pull request for that?

A test shouldn’t be too hard either.