Hey,
I’m currently trying out PyMC-BART for a classification model (as an alternative to a logistic regression model). It looks very promising (!) - but at the moment I’m having trouble integrating it into my usual workflow.
If I run model training and prediction in one script, everything works fine:
import pymc as pm
import pymc_bart as pmb

# define model
with pm.Model() as model:
    model.add_coord('id', df_train.index, mutable=True)
    model.add_coord('feature', df_train.drop(columns='Class').columns, mutable=True)
    data_x = pm.MutableData('data_x', df_train.drop(columns='Class'), dims=('id', 'feature'))
    data_class = df_train['Class']
    k1 = pmb.BART('k1', data_x, data_class, m=25, dims='id')
    p1 = pm.Deterministic('p1', pm.invlogit(k1), dims='id')
    p0 = pm.Deterministic('p0', 1 - p1, dims='id')
    pm.Bernoulli('class', p=p1, observed=data_class, dims='id')

# train
with model:
    idata = pm.sample_prior_predictive()
    idata.extend(pm.sample(chains=4))
    idata.to_netcdf(FILE_IDATA)

# predict
with model:
    pm.set_data({
        'data_x': df_test.drop(columns=['Class'], errors='ignore'),
    }, coords={'id': df_test.index})
    idata = pm.sample_posterior_predictive(idata, extend_inferencedata=True, predictions=True)
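Not essential to the problem, but for context this is roughly how I consume the result afterwards - with predictions=True the out-of-sample draws should land in the predictions group (rough sketch only, same names as above):

# out-of-sample draws for 'class' end up in idata.predictions
pred_class = idata.predictions['class']            # dims: chain, draw, id
# averaging the 0/1 Bernoulli draws over chain/draw gives a per-row probability estimate
p_test = pred_class.mean(dim=('chain', 'draw'))
print(p_test.to_series().head())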
However, I usually train the model only once and save the resulting InferenceData object for later predictions, to avoid frequent re-training. For other models this works fine - but for this BART model I’m getting the following error:
import arviz as az
import pymc as pm
import pymc_bart as pmb

# define model (same as above)
with pm.Model() as model:
    model.add_coord('id', df_train.index, mutable=True)
    model.add_coord('feature', df_train.drop(columns='Class').columns, mutable=True)
    data_x = pm.MutableData('data_x', df_train.drop(columns='Class'), dims=('id', 'feature'))
    data_class = df_train['Class']
    k1 = pmb.BART('k1', data_x, data_class, m=25, dims='id')
    p1 = pm.Deterministic('p1', pm.invlogit(k1), dims='id')
    p0 = pm.Deterministic('p0', 1 - p1, dims='id')
    pm.Bernoulli('class', p=p1, observed=data_class, dims='id')

# load previously saved InferenceData instead of re-sampling
idata = az.from_netcdf(FILE_IDATA)

# predict
with model:
    pm.set_data({
        'data_x': df_test.drop(columns=['Class'], errors='ignore'),
    }, coords={'id': df_test.index})
    idata = pm.sample_posterior_predictive(idata, extend_inferencedata=True, predictions=True)
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
File /opt/conda/lib/python3.10/site-packages/pytensor/compile/function/types.py:970, in Function.__call__(self, *args, **kwargs)
968 try:
969 outputs = (
--> 970 self.vm()
971 if output_subset is None
972 else self.vm(output_subset=output_subset)
973 )
974 except Exception:
File /opt/conda/lib/python3.10/site-packages/pytensor/graph/op.py:543, in Op.make_py_thunk.<locals>.rval(p, i, o, n, params)
539 @is_thunk_type
540 def rval(
541 p=p, i=node_input_storage, o=node_output_storage, n=node, params=None
542 ):
--> 543 r = p(n, [x[0] for x in i], o)
544 for o in node.outputs:
File /opt/conda/lib/python3.10/site-packages/pytensor/tensor/random/op.py:378, in RandomVariable.perform(self, node, inputs, outputs)
376 rng_var_out[0] = rng
--> 378 smpl_val = self.rng_fn(rng, *(args + [size]))
380 if (
381 not isinstance(smpl_val, np.ndarray)
382 or str(smpl_val.dtype) != out_var.type.dtype
383 ):
File /opt/conda/lib/python3.10/site-packages/pytensor/tensor/random/basic.py:55, in ScipyRandomVariable.rng_fn(cls, *args, **kwargs)
54 size = args[-1]
---> 55 res = cls.rng_fn_scipy(*args, **kwargs)
57 if np.ndim(res) == 0:
58 # The sample is an `np.number`, and is not writeable, or non-NumPy
59 # type, so we need to clone/create a usable NumPy result
File /opt/conda/lib/python3.10/site-packages/pytensor/tensor/random/basic.py:1473, in BernoulliRV.rng_fn_scipy(cls, rng, p, size)
1471 @classmethod
1472 def rng_fn_scipy(cls, rng, p, size):
-> 1473 return stats.bernoulli.rvs(p, size=size, random_state=rng)
File /opt/conda/lib/python3.10/site-packages/scipy/stats/_distn_infrastructure.py:3357, in rv_discrete.rvs(self, *args, **kwargs)
3356 kwargs['discrete'] = True
-> 3357 return super().rvs(*args, **kwargs)
File /opt/conda/lib/python3.10/site-packages/scipy/stats/_distn_infrastructure.py:1028, in rv_generic.rvs(self, *args, **kwds)
1027 rndm = kwds.pop('random_state', None)
-> 1028 args, loc, scale, size = self._parse_args_rvs(*args, **kwds)
1029 cond = logical_and(self._argcheck(*args), (scale >= 0))
File <string>:6, in _parse_args_rvs(self, p, loc, size)
File /opt/conda/lib/python3.10/site-packages/scipy/stats/_distn_infrastructure.py:909, in rv_generic._argcheck_rvs(self, *args, **kwargs)
908 if not ok:
--> 909 raise ValueError("size does not match the broadcast shape of "
910 "the parameters. %s, %s, %s" % (size, size_,
911 bcast_shape))
913 param_bcast = all_bcast[:-2]
ValueError: size does not match the broadcast shape of the parameters. (5,), (5,), (617,)
During handling of the above exception, another exception occurred:
ValueError Traceback (most recent call last)
Cell In[11], line 5
1 with model:
2 pm.set_data({
3 'data_x': df_test.drop(columns=['Class'], errors='ignore'),
4 }, coords={'id': df_test.index})
----> 5 idata = pm.sample_posterior_predictive(idata, extend_inferencedata=True, predictions=True)
File /opt/conda/lib/python3.10/site-packages/pymc/sampling/forward.py:644, in sample_posterior_predictive(trace, model, var_names, sample_dims, random_seed, progressbar, return_inferencedata, extend_inferencedata, predictions, idata_kwargs, compile_kwargs)
639 # there's only a single chain, but the index might hit it multiple times if
640 # the number of indices is greater than the length of the trace.
641 else:
642 param = _trace[idx % len_trace]
--> 644 values = sampler_fn(**param)
646 for k, v in zip(vars_, values):
647 ppc_trace_t.insert(k.name, v, idx)
File /opt/conda/lib/python3.10/site-packages/pymc/util.py:389, in point_wrapper.<locals>.wrapped(**kwargs)
387 def wrapped(**kwargs):
388 input_point = {k: v for k, v in kwargs.items() if k in ins}
--> 389 return core_function(**input_point)
File /opt/conda/lib/python3.10/site-packages/pytensor/compile/function/types.py:983, in Function.__call__(self, *args, **kwargs)
981 if hasattr(self.vm, "thunks"):
982 thunk = self.vm.thunks[self.vm.position_of_error]
--> 983 raise_with_op(
984 self.maker.fgraph,
985 node=self.vm.nodes[self.vm.position_of_error],
986 thunk=thunk,
987 storage_map=getattr(self.vm, "storage_map", None),
988 )
989 else:
990 # old-style linkers raise their own exceptions
991 raise
File /opt/conda/lib/python3.10/site-packages/pytensor/link/utils.py:535, in raise_with_op(fgraph, node, thunk, exc_info, storage_map)
530 warnings.warn(
531 f"{exc_type} error does not allow us to add an extra error message"
532 )
533 # Some exception need extra parameter in inputs. So forget the
534 # extra long error message in that case.
--> 535 raise exc_value.with_traceback(exc_trace)
File /opt/conda/lib/python3.10/site-packages/pytensor/compile/function/types.py:970, in Function.__call__(self, *args, **kwargs)
967 t0_fn = time.perf_counter()
968 try:
969 outputs = (
--> 970 self.vm()
971 if output_subset is None
972 else self.vm(output_subset=output_subset)
973 )
974 except Exception:
975 restore_defaults()
File /opt/conda/lib/python3.10/site-packages/pytensor/graph/op.py:543, in Op.make_py_thunk.<locals>.rval(p, i, o, n, params)
539 @is_thunk_type
540 def rval(
541 p=p, i=node_input_storage, o=node_output_storage, n=node, params=None
542 ):
--> 543 r = p(n, [x[0] for x in i], o)
544 for o in node.outputs:
545 compute_map[o][0] = True
File /opt/conda/lib/python3.10/site-packages/pytensor/tensor/random/op.py:378, in RandomVariable.perform(self, node, inputs, outputs)
374 rng = copy(rng)
376 rng_var_out[0] = rng
--> 378 smpl_val = self.rng_fn(rng, *(args + [size]))
380 if (
381 not isinstance(smpl_val, np.ndarray)
382 or str(smpl_val.dtype) != out_var.type.dtype
383 ):
384 smpl_val = _asarray(smpl_val, dtype=out_var.type.dtype)
File /opt/conda/lib/python3.10/site-packages/pytensor/tensor/random/basic.py:55, in ScipyRandomVariable.rng_fn(cls, *args, **kwargs)
52 @classmethod
53 def rng_fn(cls, *args, **kwargs):
54 size = args[-1]
---> 55 res = cls.rng_fn_scipy(*args, **kwargs)
57 if np.ndim(res) == 0:
58 # The sample is an `np.number`, and is not writeable, or non-NumPy
59 # type, so we need to clone/create a usable NumPy result
60 res = np.asarray(res)
File /opt/conda/lib/python3.10/site-packages/pytensor/tensor/random/basic.py:1473, in BernoulliRV.rng_fn_scipy(cls, rng, p, size)
1471 @classmethod
1472 def rng_fn_scipy(cls, rng, p, size):
-> 1473 return stats.bernoulli.rvs(p, size=size, random_state=rng)
File /opt/conda/lib/python3.10/site-packages/scipy/stats/_distn_infrastructure.py:3357, in rv_discrete.rvs(self, *args, **kwargs)
3328 """Random variates of given type.
3329
3330 Parameters
(...)
3354
3355 """
3356 kwargs['discrete'] = True
-> 3357 return super().rvs(*args, **kwargs)
File /opt/conda/lib/python3.10/site-packages/scipy/stats/_distn_infrastructure.py:1028, in rv_generic.rvs(self, *args, **kwds)
1026 discrete = kwds.pop('discrete', None)
1027 rndm = kwds.pop('random_state', None)
-> 1028 args, loc, scale, size = self._parse_args_rvs(*args, **kwds)
1029 cond = logical_and(self._argcheck(*args), (scale >= 0))
1030 if not np.all(cond):
File <string>:6, in _parse_args_rvs(self, p, loc, size)
File /opt/conda/lib/python3.10/site-packages/scipy/stats/_distn_infrastructure.py:909, in rv_generic._argcheck_rvs(self, *args, **kwargs)
906 ok = all([bcdim == 1 or bcdim == szdim
907 for (bcdim, szdim) in zip(bcast_shape, size_)])
908 if not ok:
--> 909 raise ValueError("size does not match the broadcast shape of "
910 "the parameters. %s, %s, %s" % (size, size_,
911 bcast_shape))
913 param_bcast = all_bcast[:-2]
914 loc_bcast = all_bcast[-2]
ValueError: size does not match the broadcast shape of the parameters. (5,), (5,), (617,)
Apply node that caused the error: bernoulli_rv{0, (0,), int64, True}(RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7D8D919E2C00>), MakeVector{dtype='int64'}.0, TensorConstant{4}, p1)
Toposort index: 3
Inputs types: [RandomGeneratorType, TensorType(int64, (1,)), TensorType(int64, ()), TensorType(float64, (?,))]
Inputs shapes: ['No shapes', (1,), (), (617,)]
Inputs strides: ['No strides', (8,), (), (8,)]
Inputs values: [Generator(PCG64) at 0x7D8D919E2C00, array([5]), array(4), 'not shown']
Outputs clients: [['output'], ['output']]
HINT: Re-running with most PyTensor optimizations disabled could provide a back-trace showing when this node was created. This can be done by setting the PyTensor flag 'optimizer=fast_compile'. If that does not work, PyTensor optimizations can be disabled with 'optimizer=None'.
HINT: Use the PyTensor flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.
It seems to me that the re-instantiation of the model object causes this problem - but I have no idea why…
Any help would be much appreciated!
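In case it helps to narrow this down: the next thing I’d try is to keep the original model object from the training session alive and only reload the InferenceData from disk, to see whether the problem is the netcdf round-trip or the freshly built model (just a sketch, reusing df_test and FILE_IDATA from above):

import arviz as az
import pymc as pm

# same Python session as training, so `model` is still the object that was sampled;
# only the InferenceData is reloaded from disk
idata_reloaded = az.from_netcdf(FILE_IDATA)

with model:
    pm.set_data({
        'data_x': df_test.drop(columns=['Class'], errors='ignore'),
    }, coords={'id': df_test.index})
    # if this works while the re-created model fails, the issue seems tied to
    # the fresh model object rather than to saving/loading the idata
    idata_reloaded = pm.sample_posterior_predictive(
        idata_reloaded, extend_inferencedata=True, predictions=True
    )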