Hi all,
I came across one of the blog posts here (Using mcbackend to store samples) that solves mid-sampling memory crashes by storing the samples in ClickHouse with mcbackend.
So far I have tested this with the default PyMC sampler, and it works perfectly. However, with the BlackJAX sampler I'm unable to store the samples in ClickHouse.
The sampling code:
```python
import mcbackend
import pymc as pm
from clickhouse_driver import Client

# `ch_connection_string`, `model`, `draws`, `tune` and `chains` are defined elsewhere.

# Establish connection to ClickHouse
ch_client = Client.from_url(ch_connection_string)
ch_backend = mcbackend.ClickHouseBackend(ch_client)

# Check that the backend loaded properly
assert isinstance(ch_backend, mcbackend.Backend)

trace = pm.sample(
    draws=draws,
    tune=tune,
    chains=chains,
    model=model,
    trace=ch_backend,
    compile_kwargs=dict(mode="NUMBA"),
    return_inferencedata=False,
    compute_convergence_checks=False,
    progressbar=False,
    nuts_sampler="blackjax",
    nuts_sampler_kwargs={
        "postprocessing_backend": "cpu",
        "chain_method": "vectorized",
    },
)
```
One of the main reasons we want to use BlackJAX is its speed.
Is there a misconfiguration here, or does mcbackend not support other NUTS samplers? Thanks!
Tagging @michaelosthege in case this is a limitation of mcbackend.