Using mcbackend to store samples from Blackjax sampler

Hi all,

I came across one of the blog posts here (Using mcbackend to store samples) that solved mid-sampling memory crashes by storing the samples in ClickHouse using mcbackend.

So far I have tested this with the default PyMC sampler, which works perfectly fine. With the Blackjax sampler, however, I'm unable to store the samples in ClickHouse.

The sampling code:

from clickhouse_driver import Client
import mcbackend
import pymc as pm

# Establish connection to ClickHouse
ch_client = Client.from_url(ch_connection_string)
ch_backend = mcbackend.ClickHouseBackend(ch_client)

# Check that the backend loaded properly
assert isinstance(ch_backend, mcbackend.Backend)

trace = pm.sample(
    draws=draws,
    tune=tune,
    chains=chains,
    model=model,
    trace=ch_backend,
    compile_kwargs=dict(mode="NUMBA"),
    return_inferencedata=False,
    compute_convergence_checks=False,
    progressbar=False,
    nuts_sampler="blackjax",
    nuts_sampler_kwargs={
        "postprocessing_backend": "cpu",
        "chain_method": "vectorized",
    },
)

One of the reasons we want to use Blackjax is its speed.

Is there a misconfiguration here? Or does mcbackend not support other NUTS samplers? Thanks!

Tagging @michaelosthege in case this is a limitation with mcbackend.

Blackjax/numpyro (the way PyMC calls them) only return samples after all sampling is done, so you won't avoid memory problems by trying to save them to mcbackend.

Does running the PyMC sampler with compile_kwargs=dict(mode="JAX") give you any speedup?

You may need to change the mp_ctx argument of pm.sample to something compatible with JAX.
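To make the first point of this reply concrete, here is a minimal toy sketch of why writing to a backend only helps when draws are handed over as they are produced. Nothing in it is PyMC, Blackjax or mcbackend code: `ListBackend`, `sample_streaming` and `sample_then_store` are hypothetical names, and the "sampler" just draws Gaussian noise.

```python
import random


class ListBackend:
    """Hypothetical stand-in for an external store (it keeps rows in a list
    only so the example runs without a database)."""

    def __init__(self):
        self.rows = []

    def append(self, draw):
        self.rows.append(draw)


def sample_streaming(n_draws, backend, seed=0):
    """Hand each draw to the backend as soon as it exists.

    Returns the largest number of draws the sampler itself held at once.
    """
    rng = random.Random(seed)
    peak_held = 0
    for _ in range(n_draws):
        draw = rng.gauss(0.0, 1.0)
        peak_held = max(peak_held, 1)  # only the current draw is held
        backend.append(draw)
    return peak_held


def sample_then_store(n_draws, backend, seed=0):
    """Collect every draw first, then copy them to the backend afterwards."""
    rng = random.Random(seed)
    draws = [rng.gauss(0.0, 1.0) for _ in range(n_draws)]
    peak_held = len(draws)  # all draws are in memory before any is stored
    for draw in draws:
        backend.append(draw)
    return peak_held
```

Both functions leave the same rows in the backend, but the second one has already paid the full memory cost before the first row is written, which is the situation described above for samplers that only return once sampling has finished.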

Hi Ricardo,

Thank you for the reply! I will need some more time to write some multiprocessing code so that JAX and ClickHouse can work together, since ClickHouse needs a separate connection in each process (I believe so!). I will pause this for a bit due to time constraints on my side, but if anyone figures it out, feel free to share the solution here 🙂