Okay, after some testing, here is what I am running:
import arviz
import numpy as np
import pymc as pm
import mcbackend as mcb
from clickhouse_driver import Client
def define_simple_model(rng=None):
    """Build a small PyMC model with scalar, vector and matrix variables.

    Synthetic observations are generated for the three conditions
    ("A", "B", "C") over the time grid ``np.linspace(0, 5)``.

    Parameters
    ----------
    rng : None, int, or numpy.random.Generator, optional
        Source of randomness for the synthetic observations. ``None`` (the
        default) keeps the original behavior of drawing from NumPy's global
        legacy random state, so ``np.random.seed`` still applies. Any other
        value is passed to ``np.random.default_rng``, which makes the data
        reproducible independently of global state.

    Returns
    -------
    pm.Model
        Model with free variables ``scalar`` and ``vector``, the deterministic
        ``matrix`` (dims ``("condition", "time")``) and the observed
        likelihood ``L``.
    """
    seconds = np.linspace(0, 5)
    # The ``np.random`` module and a ``Generator`` both expose ``uniform`` and
    # ``normal`` with compatible signatures, so either works below. The draws
    # happen in the same order as before (uniform first, then normal).
    gen = np.random if rng is None else np.random.default_rng(rng)
    slopes = gen.uniform(size=3)
    # One slope per condition: (3, 1) * (1, T) broadcasts to (3, T).
    observations = gen.normal(0.5 + slopes[:, None] * seconds[None, :])
    with pm.Model(
        coords={
            "condition": ["A", "B", "C"],
        }
    ) as pmodel:
        x = pm.ConstantData("seconds", seconds, dims="time")
        a = pm.Normal("scalar")
        b = pm.Uniform("vector", dims="condition")
        pm.Deterministic(
            "matrix", a + b[:, None] * x[None, :], dims=("condition", "time")
        )
        obs = pm.MutableData("obs", observations, dims=("condition", "time"))
        pm.Normal("L", pmodel["matrix"], observed=obs, dims=("condition", "time"))
    return pmodel
if __name__ == "__main__":
    # Build the toy model and record the draws with mcbackend's in-memory
    # NumPy backend instead of PyMC's default trace storage.
    simple_model = define_simple_model()
    backend = mcb.NumPyBackend()

    with simple_model:
        trace = pm.sample(trace=backend)

    print(trace)
This raises the following error:
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [scalar, vector]
Traceback (most recent call last):████████████████████████████████████████████████████████████| 100.00% [8000/8000 00:01<00:00 Sampling 4 chains, 0 divergences]
File "/mnt/c/Users/faust/Dropbox/Tubingen/joint_learning/model/test_clickhouse_backend.py", line 35, in <module>
trace = pm.sample(
^^^^^^^^^^
File "/home/fausto/mambaforge/envs/pymc_latest/lib/python3.11/site-packages/pymc/sampling/mcmc.py", line 702, in sample
return _sample_return(
^^^^^^^^^^^^^^^
File "/home/fausto/mambaforge/envs/pymc_latest/lib/python3.11/site-packages/pymc/sampling/mcmc.py", line 736, in _sample_return
mtrace = MultiTrace(traces)[:length]
~~~~~~~~~~~~~~~~~~^^^^^^^^^
File "/home/fausto/mambaforge/envs/pymc_latest/lib/python3.11/site-packages/pymc/backends/base.py", line 370, in __getitem__
return self._slice(idx)
^^^^^^^^^^^^^^^^
File "/home/fausto/mambaforge/envs/pymc_latest/lib/python3.11/site-packages/pymc/backends/base.py", line 537, in _slice
new_traces = [trace._slice(slice) for trace in self._straces.values()]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/fausto/mambaforge/envs/pymc_latest/lib/python3.11/site-packages/pymc/backends/base.py", line 537, in <listcomp>
new_traces = [trace._slice(slice) for trace in self._straces.values()]
^^^^^^^^^^^^^^^^^^^
File "/home/fausto/mambaforge/envs/pymc_latest/lib/python3.11/site-packages/pymc/backends/mcbackend.py", line 194, in _slice
stats = self._chain.get_stats_at(i, stat_names=snames)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/fausto/mambaforge/envs/pymc_latest/lib/python3.11/site-packages/mcbackend/backends/numpy.py", line 106, in get_stats_at
return {sn: numpy.asarray(self._stats[sn][idx]) for sn in stat_names}
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/fausto/mambaforge/envs/pymc_latest/lib/python3.11/site-packages/mcbackend/backends/numpy.py", line 106, in <dictcomp>
return {sn: numpy.asarray(self._stats[sn][idx]) for sn in stat_names}
~~~~~~~~~~~~~~~^^^^^
IndexError: index 1000 is out of bounds for axis 0 with size 1000
Ultimately, I want to run with the ClickHouse backend. But when I replace the relevant part with the following:
if __name__ == "__main__":
    simple_model = define_simple_model()

    # Connect to the local ClickHouse server and confirm the connection works
    # by listing the available databases.
    ch_client = Client("localhost")
    print(ch_client.execute("SHOW DATABASES"))

    # Stream the draws into ClickHouse through mcbackend.
    backend = mcb.ClickHouseBackend(ch_client)
    with simple_model:
        trace = pm.sample(trace=backend)

    print(trace)
a different error is raised:
[('INFORMATION_SCHEMA',), ('default',), ('information_schema',), ('system',)]
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Dimensionality of sample stat 'sampler_0__warning' is undefined. Assuming ndim=0.
Traceback (most recent call last):
File "/mnt/c/Users/faust/Dropbox/Tubingen/joint_learning/model/test_clickhouse_backend.py", line 37, in <module>
trace = pm.sample(
^^^^^^^^^^
File "/home/fausto/mambaforge/envs/pymc_latest/lib/python3.11/site-packages/pymc/sampling/mcmc.py", line 623, in sample
run, traces = init_traces(
^^^^^^^^^^^^
File "/home/fausto/mambaforge/envs/pymc_latest/lib/python3.11/site-packages/pymc/backends/__init__.py", line 127, in init_traces
return init_chain_adapters(
^^^^^^^^^^^^^^^^^^^^
File "/home/fausto/mambaforge/envs/pymc_latest/lib/python3.11/site-packages/pymc/backends/mcbackend.py", line 278, in init_chain_adapters
adapters = [
^
File "/home/fausto/mambaforge/envs/pymc_latest/lib/python3.11/site-packages/pymc/backends/mcbackend.py", line 280, in <listcomp>
chain=run.init_chain(chain_number=chain_number),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/fausto/mambaforge/envs/pymc_latest/lib/python3.11/site-packages/mcbackend/backends/clickhouse.py", line 325, in init_chain
create_chain_table(self._client, cmeta, self.meta)
File "/home/fausto/mambaforge/envs/pymc_latest/lib/python3.11/site-packages/mcbackend/backends/clickhouse.py", line 98, in create_chain_table
columns.append(column_spec_for(var, is_stat=True))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/fausto/mambaforge/envs/pymc_latest/lib/python3.11/site-packages/mcbackend/backends/clickhouse.py", line 65, in column_spec_for
raise KeyError(
KeyError: "Don't know how to store dtype object of 'sampler_0__warning' (is_stat=True) in ClickHouse."
I am using PyMC v5.3.0. Any help is appreciated!
EDIT:
I fixed the error by adding the following lines to `mcbackend/backends/clickhouse.py` at line 101:
# Workaround for: KeyError "Don't know how to store dtype object of
# 'sampler_0__warning' (is_stat=True) in ClickHouse."
# Re-declares any variable with dtype 'object' as 'str' before its column
# spec is built. In the run shown here, only `sampler_0__warning` has
# dtype 'object'.
# NOTE(review): this mutates `var` in place, so the changed dtype presumably
# also ends up in the stored run metadata; confirm.
# NOTE(review): only the declared dtype changes. The stat values themselves
# are not converted to str here, so confirm that non-None warning objects
# insert cleanly.
if var.dtype=='object':
    var.dtype='str'
As it turns out, this only affects the `sampler_0__warning` stat:
Variable(name='tune', dtype='bool', shape=[], dims=[], is_deterministic=False, undefined_ndim=False)
True
Variable(name='sampler_0__depth', dtype='int64', shape=[], dims=[], is_deterministic=False, undefined_ndim=False)
True
Variable(name='sampler_0__step_size', dtype='float64', shape=[], dims=[], is_deterministic=False, undefined_ndim=False)
True
Variable(name='sampler_0__tune', dtype='bool', shape=[], dims=[], is_deterministic=False, undefined_ndim=False)
True
Variable(name='sampler_0__mean_tree_accept', dtype='float64', shape=[], dims=[], is_deterministic=False, undefined_ndim=False)
True
Variable(name='sampler_0__step_size_bar', dtype='float64', shape=[], dims=[], is_deterministic=False, undefined_ndim=False)
True
Variable(name='sampler_0__tree_size', dtype='float64', shape=[], dims=[], is_deterministic=False, undefined_ndim=False)
True
Variable(name='sampler_0__diverging', dtype='bool', shape=[], dims=[], is_deterministic=False, undefined_ndim=False)
True
Variable(name='sampler_0__energy_error', dtype='float64', shape=[], dims=[], is_deterministic=False, undefined_ndim=False)
True
Variable(name='sampler_0__energy', dtype='float64', shape=[], dims=[], is_deterministic=False, undefined_ndim=False)
True
Variable(name='sampler_0__max_energy_error', dtype='float64', shape=[], dims=[], is_deterministic=False, undefined_ndim=False)
True
Variable(name='sampler_0__model_logp', dtype='float64', shape=[], dims=[], is_deterministic=False, undefined_ndim=False)
True
Variable(name='sampler_0__process_time_diff', dtype='float64', shape=[], dims=[], is_deterministic=False, undefined_ndim=False)
True
Variable(name='sampler_0__perf_counter_diff', dtype='float64', shape=[], dims=[], is_deterministic=False, undefined_ndim=False)
True
Variable(name='sampler_0__perf_counter_start', dtype='float64', shape=[], dims=[], is_deterministic=False, undefined_ndim=False)
True
Variable(name='sampler_0__largest_eigval', dtype='float64', shape=[], dims=[], is_deterministic=False, undefined_ndim=False)
True
Variable(name='sampler_0__smallest_eigval', dtype='float64', shape=[], dims=[], is_deterministic=False, undefined_ndim=False)
True
Variable(name='sampler_0__index_in_trajectory', dtype='int64', shape=[], dims=[], is_deterministic=False, undefined_ndim=False)
True
Variable(name='sampler_0__reached_max_treedepth', dtype='bool', shape=[], dims=[], is_deterministic=False, undefined_ndim=False)
True
Variable(name='sampler_0__warning', dtype='object', shape=[], dims=[], is_deterministic=False, undefined_ndim=True)
False
I ran the code above a few times with the ClickHouse backend (with the fix to mcbackend). Then I ran this code:
# Reconnect to the local ClickHouse server that holds the stored runs.
ch_client = Client("localhost")
# The package is imported as `mcb` at the top of this file; the bare name
# `mcbackend` is not bound there and would raise a NameError.
backend = mcb.ClickHouseBackend(ch_client)
print("list of runs:\n", backend.get_runs())
# Fetch a single run from the database (downloads just metadata).
run = backend.get_run("6F6PW")
# Convert everything to `InferenceData`.
idata = run.to_inferencedata()
print(idata)
As expected, this prints the following:
list of runs:
created_at
rid
99PHY 2023-05-19 02:54:46.005139+00:00 \
TREHW 2023-05-19 02:56:15.987728+00:00
1PQK6 2023-05-19 02:56:56.700751+00:00
YKT3N 2023-05-19 03:34:37.115745+00:00
WLKW4 2023-05-19 03:43:04.277188+00:00
9HHAC 2023-05-19 03:49:53.529272+00:00
1P93X 2023-05-19 03:51:50.767925+00:00
Q8LLA 2023-05-19 03:54:46.947725+00:00
DPP94 2023-05-19 03:56:24.189411+00:00
9CFX4 2023-05-19 03:58:19.232365+00:00
C3AZB 2023-05-19 04:01:53.742057+00:00
A36Y7 2023-05-19 04:05:15.148334+00:00
M4LND 2023-05-19 04:06:23.221856+00:00
6F6PW 2023-05-19 04:44:46.898726+00:00
proto
rid
99PHY RunMeta(rid='99PHY', variables=[Variable(name=...
TREHW RunMeta(rid='TREHW', variables=[Variable(name=...
1PQK6 RunMeta(rid='1PQK6', variables=[Variable(name=...
YKT3N RunMeta(rid='YKT3N', variables=[Variable(name=...
WLKW4 RunMeta(rid='WLKW4', variables=[Variable(name=...
9HHAC RunMeta(rid='9HHAC', variables=[Variable(name=...
1P93X RunMeta(rid='1P93X', variables=[Variable(name=...
Q8LLA RunMeta(rid='Q8LLA', variables=[Variable(name=...
DPP94 RunMeta(rid='DPP94', variables=[Variable(name=...
9CFX4 RunMeta(rid='9CFX4', variables=[Variable(name=...
C3AZB RunMeta(rid='C3AZB', variables=[Variable(name=...
A36Y7 RunMeta(rid='A36Y7', variables=[Variable(name=...
M4LND RunMeta(rid='M4LND', variables=[Variable(name=...
6F6PW RunMeta(rid='6F6PW', variables=[Variable(name=...
Inference data with groups:
> posterior
> sample_stats
> observed_data
> constant_data
Warmup iterations saved (warmup_*).
/home/fausto/mambaforge/envs/pymc_latest/lib/python3.11/site-packages/arviz/data/base.py:221: UserWarning: More chains (3) than draws (0). Passed array should have shape (chains, draws, *shape)
warnings.warn(
/home/fausto/mambaforge/envs/pymc_latest/lib/python3.11/site-packages/arviz/data/base.py:221: UserWarning: More chains (3) than draws (0). Passed array should have shape (chains, draws, *shape)
warnings.warn(
/home/fausto/mambaforge/envs/pymc_latest/lib/python3.11/site-packages/arviz/data/base.py:221: UserWarning: More chains (3) than draws (0). Passed array should have shape (chains, draws, *shape)
warnings.warn(
One problem remains: the `warmup_sample_stats` group of the InferenceData contains a number of fields, but they are all empty.