Problem loading a trained BART model for prediction

Hey, :slight_smile:
I’m currently trying out PyMC-BART for a classification model (as an alternative to a logistic regression model). It looks very promising (!), but I’m having trouble integrating it into my usual workflow.

If I run model training and prediction in one script, everything works fine:

import arviz as az
import pymc as pm
import pymc_bart as pmb

# define model
with pm.Model() as model:
    model.add_coord('id', df_train.index, mutable=True)
    model.add_coord('feature', df_train.drop(columns='Class').columns, mutable=True)

    data_x  = pm.MutableData('data_x', df_train.drop(columns='Class'), dims=('id', 'feature'))
    data_class = df_train['Class']

    k1 = pmb.BART('k1', data_x, data_class, m=25, dims='id')
    p1 = pm.Deterministic('p1', pm.invlogit(k1), dims='id')
    p0 = pm.Deterministic('p0', 1 - p1, dims='id')
    
    pm.Bernoulli('class', p=p1, observed=data_class, dims='id')

# train
with model:
    idata = pm.sample_prior_predictive()
    idata.extend(pm.sample(chains=4))
idata.to_netcdf(FILE_IDATA)

# predict
with model:
    pm.set_data({
        'data_x': df_test.drop(columns=['Class'], errors='ignore'),
    }, coords={'id': df_test.index})
    idata = pm.sample_posterior_predictive(idata, extend_inferencedata=True, predictions=True)

However, I usually train the model only once and save the resulting InferenceData object for later predictions, to avoid frequent re-training. For other models this works fine, but for this BART model I’m getting the following error:

# define model (same as above)
with pm.Model() as model:
    model.add_coord('id', df_train.index, mutable=True)
    model.add_coord('feature', df_train.drop(columns='Class').columns, mutable=True)

    data_x  = pm.MutableData('data_x', df_train.drop(columns='Class'), dims=('id', 'feature'))
    data_class = df_train['Class']

    k1 = pmb.BART('k1', data_x, data_class, m=25, dims='id')
    p1 = pm.Deterministic('p1', pm.invlogit(k1), dims='id')
    p0 = pm.Deterministic('p0', 1 - p1, dims='id')
    
    pm.Bernoulli('class', p=p1, observed=data_class, dims='id')

# load
idata = az.from_netcdf(FILE_IDATA)

# predict
with model:
    pm.set_data({
        'data_x': df_test.drop(columns=['Class'], errors='ignore'),
    }, coords={'id': df_test.index})
    idata = pm.sample_posterior_predictive(idata, extend_inferencedata=True, predictions=True)
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
File /opt/conda/lib/python3.10/site-packages/pytensor/compile/function/types.py:970, in Function.__call__(self, *args, **kwargs)
    968 try:
    969     outputs = (
--> 970         self.vm()
    971         if output_subset is None
    972         else self.vm(output_subset=output_subset)
    973     )
    974 except Exception:

File /opt/conda/lib/python3.10/site-packages/pytensor/graph/op.py:543, in Op.make_py_thunk.<locals>.rval(p, i, o, n, params)
    539 @is_thunk_type
    540 def rval(
    541     p=p, i=node_input_storage, o=node_output_storage, n=node, params=None
    542 ):
--> 543     r = p(n, [x[0] for x in i], o)
    544     for o in node.outputs:

File /opt/conda/lib/python3.10/site-packages/pytensor/tensor/random/op.py:378, in RandomVariable.perform(self, node, inputs, outputs)
    376 rng_var_out[0] = rng
--> 378 smpl_val = self.rng_fn(rng, *(args + [size]))
    380 if (
    381     not isinstance(smpl_val, np.ndarray)
    382     or str(smpl_val.dtype) != out_var.type.dtype
    383 ):

File /opt/conda/lib/python3.10/site-packages/pytensor/tensor/random/basic.py:55, in ScipyRandomVariable.rng_fn(cls, *args, **kwargs)
     54 size = args[-1]
---> 55 res = cls.rng_fn_scipy(*args, **kwargs)
     57 if np.ndim(res) == 0:
     58     # The sample is an `np.number`, and is not writeable, or non-NumPy
     59     # type, so we need to clone/create a usable NumPy result

File /opt/conda/lib/python3.10/site-packages/pytensor/tensor/random/basic.py:1473, in BernoulliRV.rng_fn_scipy(cls, rng, p, size)
   1471 @classmethod
   1472 def rng_fn_scipy(cls, rng, p, size):
-> 1473     return stats.bernoulli.rvs(p, size=size, random_state=rng)

File /opt/conda/lib/python3.10/site-packages/scipy/stats/_distn_infrastructure.py:3357, in rv_discrete.rvs(self, *args, **kwargs)
   3356 kwargs['discrete'] = True
-> 3357 return super().rvs(*args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/scipy/stats/_distn_infrastructure.py:1028, in rv_generic.rvs(self, *args, **kwds)
   1027 rndm = kwds.pop('random_state', None)
-> 1028 args, loc, scale, size = self._parse_args_rvs(*args, **kwds)
   1029 cond = logical_and(self._argcheck(*args), (scale >= 0))

File <string>:6, in _parse_args_rvs(self, p, loc, size)

File /opt/conda/lib/python3.10/site-packages/scipy/stats/_distn_infrastructure.py:909, in rv_generic._argcheck_rvs(self, *args, **kwargs)
    908 if not ok:
--> 909     raise ValueError("size does not match the broadcast shape of "
    910                      "the parameters. %s, %s, %s" % (size, size_,
    911                                                      bcast_shape))
    913 param_bcast = all_bcast[:-2]

ValueError: size does not match the broadcast shape of the parameters. (5,), (5,), (617,)

During handling of the above exception, another exception occurred:

ValueError                                Traceback (most recent call last)
Cell In[11], line 5
      1 with model:
      2     pm.set_data({
      3         'data_x': df_test.drop(columns=['Class'], errors='ignore'),
      4     }, coords={'id': df_test.index})
----> 5     idata = pm.sample_posterior_predictive(idata, extend_inferencedata=True, predictions=True)

File /opt/conda/lib/python3.10/site-packages/pymc/sampling/forward.py:644, in sample_posterior_predictive(trace, model, var_names, sample_dims, random_seed, progressbar, return_inferencedata, extend_inferencedata, predictions, idata_kwargs, compile_kwargs)
    639 # there's only a single chain, but the index might hit it multiple times if
    640 # the number of indices is greater than the length of the trace.
    641 else:
    642     param = _trace[idx % len_trace]
--> 644 values = sampler_fn(**param)
    646 for k, v in zip(vars_, values):
    647     ppc_trace_t.insert(k.name, v, idx)

File /opt/conda/lib/python3.10/site-packages/pymc/util.py:389, in point_wrapper.<locals>.wrapped(**kwargs)
    387 def wrapped(**kwargs):
    388     input_point = {k: v for k, v in kwargs.items() if k in ins}
--> 389     return core_function(**input_point)

File /opt/conda/lib/python3.10/site-packages/pytensor/compile/function/types.py:983, in Function.__call__(self, *args, **kwargs)
    981     if hasattr(self.vm, "thunks"):
    982         thunk = self.vm.thunks[self.vm.position_of_error]
--> 983     raise_with_op(
    984         self.maker.fgraph,
    985         node=self.vm.nodes[self.vm.position_of_error],
    986         thunk=thunk,
    987         storage_map=getattr(self.vm, "storage_map", None),
    988     )
    989 else:
    990     # old-style linkers raise their own exceptions
    991     raise

File /opt/conda/lib/python3.10/site-packages/pytensor/link/utils.py:535, in raise_with_op(fgraph, node, thunk, exc_info, storage_map)
    530     warnings.warn(
    531         f"{exc_type} error does not allow us to add an extra error message"
    532     )
    533     # Some exception need extra parameter in inputs. So forget the
    534     # extra long error message in that case.
--> 535 raise exc_value.with_traceback(exc_trace)

File /opt/conda/lib/python3.10/site-packages/pytensor/compile/function/types.py:970, in Function.__call__(self, *args, **kwargs)
    967 t0_fn = time.perf_counter()
    968 try:
    969     outputs = (
--> 970         self.vm()
    971         if output_subset is None
    972         else self.vm(output_subset=output_subset)
    973     )
    974 except Exception:
    975     restore_defaults()

File /opt/conda/lib/python3.10/site-packages/pytensor/graph/op.py:543, in Op.make_py_thunk.<locals>.rval(p, i, o, n, params)
    539 @is_thunk_type
    540 def rval(
    541     p=p, i=node_input_storage, o=node_output_storage, n=node, params=None
    542 ):
--> 543     r = p(n, [x[0] for x in i], o)
    544     for o in node.outputs:
    545         compute_map[o][0] = True

File /opt/conda/lib/python3.10/site-packages/pytensor/tensor/random/op.py:378, in RandomVariable.perform(self, node, inputs, outputs)
    374     rng = copy(rng)
    376 rng_var_out[0] = rng
--> 378 smpl_val = self.rng_fn(rng, *(args + [size]))
    380 if (
    381     not isinstance(smpl_val, np.ndarray)
    382     or str(smpl_val.dtype) != out_var.type.dtype
    383 ):
    384     smpl_val = _asarray(smpl_val, dtype=out_var.type.dtype)

File /opt/conda/lib/python3.10/site-packages/pytensor/tensor/random/basic.py:55, in ScipyRandomVariable.rng_fn(cls, *args, **kwargs)
     52 @classmethod
     53 def rng_fn(cls, *args, **kwargs):
     54     size = args[-1]
---> 55     res = cls.rng_fn_scipy(*args, **kwargs)
     57     if np.ndim(res) == 0:
     58         # The sample is an `np.number`, and is not writeable, or non-NumPy
     59         # type, so we need to clone/create a usable NumPy result
     60         res = np.asarray(res)

File /opt/conda/lib/python3.10/site-packages/pytensor/tensor/random/basic.py:1473, in BernoulliRV.rng_fn_scipy(cls, rng, p, size)
   1471 @classmethod
   1472 def rng_fn_scipy(cls, rng, p, size):
-> 1473     return stats.bernoulli.rvs(p, size=size, random_state=rng)

File /opt/conda/lib/python3.10/site-packages/scipy/stats/_distn_infrastructure.py:3357, in rv_discrete.rvs(self, *args, **kwargs)
   3328 """Random variates of given type.
   3329 
   3330 Parameters
   (...)
   3354 
   3355 """
   3356 kwargs['discrete'] = True
-> 3357 return super().rvs(*args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/scipy/stats/_distn_infrastructure.py:1028, in rv_generic.rvs(self, *args, **kwds)
   1026 discrete = kwds.pop('discrete', None)
   1027 rndm = kwds.pop('random_state', None)
-> 1028 args, loc, scale, size = self._parse_args_rvs(*args, **kwds)
   1029 cond = logical_and(self._argcheck(*args), (scale >= 0))
   1030 if not np.all(cond):

File <string>:6, in _parse_args_rvs(self, p, loc, size)

File /opt/conda/lib/python3.10/site-packages/scipy/stats/_distn_infrastructure.py:909, in rv_generic._argcheck_rvs(self, *args, **kwargs)
    906 ok = all([bcdim == 1 or bcdim == szdim
    907           for (bcdim, szdim) in zip(bcast_shape, size_)])
    908 if not ok:
--> 909     raise ValueError("size does not match the broadcast shape of "
    910                      "the parameters. %s, %s, %s" % (size, size_,
    911                                                      bcast_shape))
    913 param_bcast = all_bcast[:-2]
    914 loc_bcast = all_bcast[-2]

ValueError: size does not match the broadcast shape of the parameters. (5,), (5,), (617,)
Apply node that caused the error: bernoulli_rv{0, (0,), int64, True}(RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7D8D919E2C00>), MakeVector{dtype='int64'}.0, TensorConstant{4}, p1)
Toposort index: 3
Inputs types: [RandomGeneratorType, TensorType(int64, (1,)), TensorType(int64, ()), TensorType(float64, (?,))]
Inputs shapes: ['No shapes', (1,), (), (617,)]
Inputs strides: ['No strides', (8,), (), (8,)]
Inputs values: [Generator(PCG64) at 0x7D8D919E2C00, array([5]), array(4), 'not shown']
Outputs clients: [['output'], ['output']]

HINT: Re-running with most PyTensor optimizations disabled could provide a back-trace showing when this node was created. This can be done by setting the PyTensor flag 'optimizer=fast_compile'. If that does not work, PyTensor optimizations can be disabled with 'optimizer=None'.
HINT: Use the PyTensor flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.

It seems to me that the re-instantiation of the model object causes this problem - but I have no idea why…
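
For readers decoding the traceback: the three tuples in the ValueError are the requested size, the size after padding, and the broadcast shape of the parameters. The sketch below is a simplified pure-Python mirror of the check visible in the `_argcheck_rvs` frame above, not SciPy's actual code, and the function name `size_matches_params` is made up for illustration. It suggests that the Bernoulli draw was asked for 5 values (the new `id` length) while `p1` still had 617 entries (presumably the training length).

```python
def size_matches_params(size, param_shape):
    """Return True if `size` is compatible with the broadcast parameter shape."""
    size = tuple(size)
    param_shape = tuple(param_shape)
    # Left-pad the shorter tuple with 1s so both have the same length.
    ndiff = len(param_shape) - len(size)
    if ndiff > 0:
        size = (1,) * ndiff + size
    elif ndiff < 0:
        param_shape = (1,) * (-ndiff) + param_shape
    # Each parameter dimension must be 1 or equal to the requested size.
    return all(p == 1 or p == s for p, s in zip(param_shape, size))


print(size_matches_params((5,), (617,)))    # False: 5 draws requested, p1 has 617 entries
print(size_matches_params((617,), (617,)))  # True: shapes agree
```

So the error is about two lengths that disagree inside the likelihood, rather than about the sampler itself.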

Any help would be much appreciated! :slight_smile:

Does anyone have an idea how I could save the idata from a BART model and use it later for prediction?

CC @aloctavodia (Sorry for pinging you, but from other threads it seems you know the BART implementation best :wink:)

The first issue is that you need to do

pm.Bernoulli('class', p=p1, observed=data_class, dims='id', shape=data_x.shape)

and then you have another issue saving the idata, right?

@aloctavodia Thanks for your fast reply! :slight_smile:

The problem is not so much saving the idata as loading it later to run predictions. If I train the model and run predictions in one notebook, everything works fine. But if I train the model and save the idata in one notebook, and later load the idata in another notebook to run predictions, I get the error that I posted above.

I’ve just tried adding shape=data_x.shape, but this gave me an error while building the model:

---------------------------------------------------------------------------
ShapeError                                Traceback (most recent call last)
Cell In[35], line 1
----> 1 model = build_model()
      2 pm.model_to_graphviz(model)

Cell In[34], line 14, in build_model()
     11     p1 = pm.Deterministic('p1', pm.invlogit(k1), dims='id')
     12     p0 = pm.Deterministic('p0', 1 - p1, dims='id')
---> 14     pm.Bernoulli('class', p=p1, observed=data_class, dims='id', shape=data_x.shape)
     15 return model

File /opt/conda/lib/python3.10/site-packages/pymc/distributions/distribution.py:460, in Discrete.__new__(cls, name, *args, **kwargs)
    457 if kwargs.get("transform", None):
    458     raise ValueError("Transformations for discrete distributions")
--> 460 return super().__new__(cls, name, *args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/pymc/distributions/distribution.py:314, in Distribution.__new__(cls, name, rng, dims, initval, observed, total_size, transform, *args, **kwargs)
    310         kwargs["shape"] = tuple(observed.shape)
    312 rv_out = cls.dist(*args, **kwargs)
--> 314 rv_out = model.register_rv(
    315     rv_out,
    316     name,
    317     observed,
    318     total_size,
    319     dims=dims,
    320     transform=transform,
    321     initval=initval,
    322 )
    324 # add in pretty-printing support
    325 rv_out.str_repr = types.MethodType(str_for_dist, rv_out)

File /opt/conda/lib/python3.10/site-packages/pymc/model.py:1356, in Model.register_rv(self, rv_var, name, observed, total_size, dims, transform, initval)
   1349         raise TypeError(
   1350             "Variables that depend on other nodes cannot be used for observed data."
   1351             f"The data variable was: {observed}"
   1352         )
   1354     # `rv_var` is potentially changed by `make_obs_var`,
   1355     # for example into a new graph for imputation of missing data.
-> 1356     rv_var = self.make_obs_var(rv_var, observed, dims, transform, total_size)
   1358 return rv_var

File /opt/conda/lib/python3.10/site-packages/pymc/model.py:1387, in Model.make_obs_var(self, rv_var, data, dims, transform, total_size)
   1384 data = convert_observed_data(data).astype(rv_var.dtype)
   1386 if data.ndim != rv_var.ndim:
-> 1387     raise ShapeError(
   1388         "Dimensionality of data and RV don't match.", actual=data.ndim, expected=rv_var.ndim
   1389     )
   1391 if pytensor.config.compute_test_value != "off":
   1392     test_value = getattr(rv_var.tag, "test_value", None)

ShapeError: Dimensionality of data and RV don't match. (actual 1 != expected 2)
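
The ShapeError follows from the dimensionality of the two objects: `data_x` is a 2-D table (rows by features), so `data_x.shape` has two entries, while the observed `data_class` is 1-D. A minimal sketch with plain tuples, where the feature count 10 is a made-up placeholder and the 617 rows are assumed from the earlier traceback:

```python
# Shapes as plain tuples; 10 features is a hypothetical value.
data_x_shape = (617, 10)   # rows x features
data_class_shape = (617,)  # one label per row


def ndim(shape):
    """Number of dimensions implied by a shape tuple."""
    return len(shape)


# shape=data_x.shape asks for a 2-D random variable, but the observations are 1-D.
print(ndim(data_x_shape), ndim(data_class_shape))  # 2 1

# A 1-D shape built from the row count alone has matching dimensionality.
rows_only_shape = data_x_shape[:1]
print(rows_only_shape, ndim(rows_only_shape) == ndim(data_class_shape))  # (617,) True
```

In the model this would correspond to something like shape=data_x.shape[0] or shape=p1.shape. Whether that also resolves the reload problem is not established in this thread.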

I also tried adding shape=data_class.shape instead. That let me build the model and sample, but during prediction I then get this error:

ValueError                                Traceback (most recent call last)
File /opt/conda/lib/python3.10/site-packages/pytensor/compile/function/types.py:970, in Function.__call__(self, *args, **kwargs)
    968 try:
    969     outputs = (
--> 970         self.vm()
    971         if output_subset is None
    972         else self.vm(output_subset=output_subset)
    973     )
    974 except Exception:

File /opt/conda/lib/python3.10/site-packages/pytensor/graph/op.py:543, in Op.make_py_thunk.<locals>.rval(p, i, o, n, params)
    539 @is_thunk_type
    540 def rval(
    541     p=p, i=node_input_storage, o=node_output_storage, n=node, params=None
    542 ):
--> 543     r = p(n, [x[0] for x in i], o)
    544     for o in node.outputs:

File /opt/conda/lib/python3.10/site-packages/pytensor/tensor/random/op.py:378, in RandomVariable.perform(self, node, inputs, outputs)
    376 rng_var_out[0] = rng
--> 378 smpl_val = self.rng_fn(rng, *(args + [size]))
    380 if (
    381     not isinstance(smpl_val, np.ndarray)
    382     or str(smpl_val.dtype) != out_var.type.dtype
    383 ):

File /opt/conda/lib/python3.10/site-packages/pytensor/tensor/random/basic.py:55, in ScipyRandomVariable.rng_fn(cls, *args, **kwargs)
     54 size = args[-1]
---> 55 res = cls.rng_fn_scipy(*args, **kwargs)
     57 if np.ndim(res) == 0:
     58     # The sample is an `np.number`, and is not writeable, or non-NumPy
     59     # type, so we need to clone/create a usable NumPy result

File /opt/conda/lib/python3.10/site-packages/pytensor/tensor/random/basic.py:1473, in BernoulliRV.rng_fn_scipy(cls, rng, p, size)
   1471 @classmethod
   1472 def rng_fn_scipy(cls, rng, p, size):
-> 1473     return stats.bernoulli.rvs(p, size=size, random_state=rng)

File /opt/conda/lib/python3.10/site-packages/scipy/stats/_distn_infrastructure.py:3357, in rv_discrete.rvs(self, *args, **kwargs)
   3356 kwargs['discrete'] = True
-> 3357 return super().rvs(*args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/scipy/stats/_distn_infrastructure.py:1028, in rv_generic.rvs(self, *args, **kwds)
   1027 rndm = kwds.pop('random_state', None)
-> 1028 args, loc, scale, size = self._parse_args_rvs(*args, **kwds)
   1029 cond = logical_and(self._argcheck(*args), (scale >= 0))

File <string>:6, in _parse_args_rvs(self, p, loc, size)

File /opt/conda/lib/python3.10/site-packages/scipy/stats/_distn_infrastructure.py:909, in rv_generic._argcheck_rvs(self, *args, **kwargs)
    908 if not ok:
--> 909     raise ValueError("size does not match the broadcast shape of "
    910                      "the parameters. %s, %s, %s" % (size, size_,
    911                                                      bcast_shape))
    913 param_bcast = all_bcast[:-2]

ValueError: size does not match the broadcast shape of the parameters. (617,), (617,), (5,)

During handling of the above exception, another exception occurred:

ValueError                                Traceback (most recent call last)
Cell In[12], line 5
      1 with model:
      2     pm.set_data({
      3         'data_x': df_test.drop(columns=['Class'], errors='ignore'),
      4     }, coords={'id': df_test.index})
----> 5     idata = pm.sample_posterior_predictive(idata, extend_inferencedata=True, predictions=True)
      6 idata.to_netcdf(FILE_IDATA)
      7 display(idata)

File /opt/conda/lib/python3.10/site-packages/pymc/sampling/forward.py:644, in sample_posterior_predictive(trace, model, var_names, sample_dims, random_seed, progressbar, return_inferencedata, extend_inferencedata, predictions, idata_kwargs, compile_kwargs)
    639 # there's only a single chain, but the index might hit it multiple times if
    640 # the number of indices is greater than the length of the trace.
    641 else:
    642     param = _trace[idx % len_trace]
--> 644 values = sampler_fn(**param)
    646 for k, v in zip(vars_, values):
    647     ppc_trace_t.insert(k.name, v, idx)

File /opt/conda/lib/python3.10/site-packages/pymc/util.py:389, in point_wrapper.<locals>.wrapped(**kwargs)
    387 def wrapped(**kwargs):
    388     input_point = {k: v for k, v in kwargs.items() if k in ins}
--> 389     return core_function(**input_point)

File /opt/conda/lib/python3.10/site-packages/pytensor/compile/function/types.py:983, in Function.__call__(self, *args, **kwargs)
    981     if hasattr(self.vm, "thunks"):
    982         thunk = self.vm.thunks[self.vm.position_of_error]
--> 983     raise_with_op(
    984         self.maker.fgraph,
    985         node=self.vm.nodes[self.vm.position_of_error],
    986         thunk=thunk,
    987         storage_map=getattr(self.vm, "storage_map", None),
    988     )
    989 else:
    990     # old-style linkers raise their own exceptions
    991     raise

File /opt/conda/lib/python3.10/site-packages/pytensor/link/utils.py:535, in raise_with_op(fgraph, node, thunk, exc_info, storage_map)
    530     warnings.warn(
    531         f"{exc_type} error does not allow us to add an extra error message"
    532     )
    533     # Some exception need extra parameter in inputs. So forget the
    534     # extra long error message in that case.
--> 535 raise exc_value.with_traceback(exc_trace)

File /opt/conda/lib/python3.10/site-packages/pytensor/compile/function/types.py:970, in Function.__call__(self, *args, **kwargs)
    967 t0_fn = time.perf_counter()
    968 try:
    969     outputs = (
--> 970         self.vm()
    971         if output_subset is None
    972         else self.vm(output_subset=output_subset)
    973     )
    974 except Exception:
    975     restore_defaults()

File /opt/conda/lib/python3.10/site-packages/pytensor/graph/op.py:543, in Op.make_py_thunk.<locals>.rval(p, i, o, n, params)
    539 @is_thunk_type
    540 def rval(
    541     p=p, i=node_input_storage, o=node_output_storage, n=node, params=None
    542 ):
--> 543     r = p(n, [x[0] for x in i], o)
    544     for o in node.outputs:
    545         compute_map[o][0] = True

File /opt/conda/lib/python3.10/site-packages/pytensor/tensor/random/op.py:378, in RandomVariable.perform(self, node, inputs, outputs)
    374     rng = copy(rng)
    376 rng_var_out[0] = rng
--> 378 smpl_val = self.rng_fn(rng, *(args + [size]))
    380 if (
    381     not isinstance(smpl_val, np.ndarray)
    382     or str(smpl_val.dtype) != out_var.type.dtype
    383 ):
    384     smpl_val = _asarray(smpl_val, dtype=out_var.type.dtype)

File /opt/conda/lib/python3.10/site-packages/pytensor/tensor/random/basic.py:55, in ScipyRandomVariable.rng_fn(cls, *args, **kwargs)
     52 @classmethod
     53 def rng_fn(cls, *args, **kwargs):
     54     size = args[-1]
---> 55     res = cls.rng_fn_scipy(*args, **kwargs)
     57     if np.ndim(res) == 0:
     58         # The sample is an `np.number`, and is not writeable, or non-NumPy
     59         # type, so we need to clone/create a usable NumPy result
     60         res = np.asarray(res)

File /opt/conda/lib/python3.10/site-packages/pytensor/tensor/random/basic.py:1473, in BernoulliRV.rng_fn_scipy(cls, rng, p, size)
   1471 @classmethod
   1472 def rng_fn_scipy(cls, rng, p, size):
-> 1473     return stats.bernoulli.rvs(p, size=size, random_state=rng)

File /opt/conda/lib/python3.10/site-packages/scipy/stats/_distn_infrastructure.py:3357, in rv_discrete.rvs(self, *args, **kwargs)
   3328 """Random variates of given type.
   3329 
   3330 Parameters
   (...)
   3354 
   3355 """
   3356 kwargs['discrete'] = True
-> 3357 return super().rvs(*args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/scipy/stats/_distn_infrastructure.py:1028, in rv_generic.rvs(self, *args, **kwds)
   1026 discrete = kwds.pop('discrete', None)
   1027 rndm = kwds.pop('random_state', None)
-> 1028 args, loc, scale, size = self._parse_args_rvs(*args, **kwds)
   1029 cond = logical_and(self._argcheck(*args), (scale >= 0))
   1030 if not np.all(cond):

File <string>:6, in _parse_args_rvs(self, p, loc, size)

File /opt/conda/lib/python3.10/site-packages/scipy/stats/_distn_infrastructure.py:909, in rv_generic._argcheck_rvs(self, *args, **kwargs)
    906 ok = all([bcdim == 1 or bcdim == szdim
    907           for (bcdim, szdim) in zip(bcast_shape, size_)])
    908 if not ok:
--> 909     raise ValueError("size does not match the broadcast shape of "
    910                      "the parameters. %s, %s, %s" % (size, size_,
    911                                                      bcast_shape))
    913 param_bcast = all_bcast[:-2]
    914 loc_bcast = all_bcast[-2]

ValueError: size does not match the broadcast shape of the parameters. (617,), (617,), (5,)
Apply node that caused the error: bernoulli_rv{0, (0,), int64, True}(RandomGeneratorSharedVariable(<Generator(PCG64) at 0x79F4A8C210E0>), TensorConstant{(1,) of 617}, TensorConstant{4}, p1)
Toposort index: 2
Inputs types: [RandomGeneratorType, TensorType(int64, (1,)), TensorType(int64, ()), TensorType(float64, (?,))]
Inputs shapes: ['No shapes', (1,), (), (5,)]
Inputs strides: ['No strides', (8,), (), (8,)]
Inputs values: [Generator(PCG64) at 0x79F4A8C210E0, array([617]), array(4), array([0.35075572, 0.35075572, 0.35075572, 0.35075572, 0.35075572])]
Outputs clients: [['output'], ['output']]

HINT: Re-running with most PyTensor optimizations disabled could provide a back-trace showing when this node was created. This can be done by setting the PyTensor flag 'optimizer=fast_compile'. If that does not work, PyTensor optimizations can be disabled with 'optimizer=None'.
HINT: Use the PyTensor flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.

Just a short ping on this issue, since I wanted to use BART for another project as well and still run into the same problem.

Is there any possibility to support the following workflow with a BART model?

  1. train.ipynb
    • Instantiate PyMC model with df_train data
    • Sample from posterior via pm.sample()
    • Save returned idata to file
  2. predict.ipynb
    • Instantiate same PyMC model with df_predict data
    • Load previous idata from file
    • Sample from posterior predictive via pm.sample_posterior_predictive(idata, predictions=True)
    • (Do further analysis with the prediction…)

This workflow is quite efficient for other PyMC models, since I don’t need to re-call pm.sample() in the predict.ipynb notebook - but it fails with BART models (see posts above).

I have the feeling that the information stored in idata.posterior is not sufficient to recreate the BART posterior distribution in predict.ipynb, but honestly, I’m not an expert with this type of model.

Hi krum_sv,

I was trying to implement a similar workflow and figured out a workaround. I submitted an issue on GitHub and added my workaround at the bottom of the page:
BART model save, reload and new predictions · Issue #123 · pymc-devs/pymc-bart · GitHub

From what I figured out, the idata currently doesn’t contain the tree object needed for reloading/sampling the model. However, I found I can save the tree object (as a pkl) and then read it back in. A utility function (_sample_posterior()) from pmb.utils can then be used to return sample predictions from the trees for a given dataset.
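
The save-and-reload step can be sketched with the standard `pickle` module. This is only an illustration under assumptions: `trees` below is a stand-in list of dictionaries, not a real PyMC-BART tree object, and the file name is invented. How to obtain the actual tree object from the model is described in the linked issue.

```python
import pickle
import tempfile
from pathlib import Path

# Stand-in for the fitted tree object (hypothetical structure, not PyMC-BART's).
trees = [{"split_feature": 0, "split_value": 1.5, "leaf_values": [-0.2, 0.4]}]

# Hypothetical file location inside a temporary directory.
tree_file = Path(tempfile.mkdtemp()) / "bart_trees.pkl"

# train.ipynb: save the object next to the idata file.
with open(tree_file, "wb") as fh:
    pickle.dump(trees, fh)

# predict.ipynb: load it again before computing predictions.
with open(tree_file, "rb") as fh:
    trees_loaded = pickle.load(fh)

print(trees_loaded == trees)  # True
```

Pickle files should only be loaded from trusted sources, and unpickling a library object generally requires a compatible version of the library that defined it.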

This worked when I tested it and gave results similar to using sample_posterior_predictive on a model trained in the same session.

I wasn’t able to figure out how to get the tree object back into a newly instantiated model, so it isn’t perfect, but it should work for some projects.
