ArviZ's `to_netcdf` having issues saving my trace object

Hello,

I’m trying to save a trace after a six-hour model run. The ArviZ documentation says to use `az.to_netcdf`, but when I try it I get the following error.

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
/tmp/ipykernel_3891/3643003175.py in <module>
----> 1 az.to_netcdf(trace, filename='prophet_trace')

/opt/conda/lib/python3.7/site-packages/arviz/data/io_netcdf.py in to_netcdf(data, filename, group, coords, dims)
     59     """
     60     inference_data = convert_to_inference_data(data, group=group, coords=coords, dims=dims)
---> 61     file_name = inference_data.to_netcdf(filename)
     62     return file_name

/opt/conda/lib/python3.7/site-packages/arviz/data/inference_data.py in to_netcdf(self, filename, compress, groups)
    425                 if compress:
    426                     kwargs["encoding"] = {var_name: {"zlib": True} for var_name in data.variables}
--> 427                 data.to_netcdf(filename, mode=mode, group=group, **kwargs)
    428                 data.close()
    429                 mode = "a"

/opt/conda/lib/python3.7/site-packages/xarray/core/dataset.py in to_netcdf(self, path, mode, format, group, engine, encoding, unlimited_dims, compute, invalid_netcdf)
   1910             unlimited_dims=unlimited_dims,
   1911             compute=compute,
-> 1912             invalid_netcdf=invalid_netcdf,
   1913         )
   1914 

/opt/conda/lib/python3.7/site-packages/xarray/backends/api.py in to_netcdf(dataset, path_or_file, mode, format, group, engine, encoding, unlimited_dims, compute, multifile, invalid_netcdf)
   1071         # to be parallelized with dask
   1072         dump_to_store(
-> 1073             dataset, store, writer, encoding=encoding, unlimited_dims=unlimited_dims
   1074         )
   1075         if autoclose:

/opt/conda/lib/python3.7/site-packages/xarray/backends/api.py in dump_to_store(dataset, store, writer, encoder, encoding, unlimited_dims)
   1117         variables, attrs = encoder(variables, attrs)
   1118 
-> 1119     store.store(variables, attrs, check_encoding, writer, unlimited_dims=unlimited_dims)
   1120 
   1121 

/opt/conda/lib/python3.7/site-packages/xarray/backends/common.py in store(self, variables, attributes, check_encoding_set, writer, unlimited_dims)
    259             writer = ArrayWriter()
    260 
--> 261         variables, attributes = self.encode(variables, attributes)
    262 
    263         self.set_attributes(attributes)

/opt/conda/lib/python3.7/site-packages/xarray/backends/common.py in encode(self, variables, attributes)
    348         # All NetCDF files get CF encoded by default, without this attempting
    349         # to write times, for example, would fail.
--> 350         variables, attributes = cf_encoder(variables, attributes)
    351         variables = {k: self.encode_variable(v) for k, v in variables.items()}
    352         attributes = {k: self.encode_attribute(v) for k, v in attributes.items()}

/opt/conda/lib/python3.7/site-packages/xarray/conventions.py in cf_encoder(variables, attributes)
    853     _update_bounds_encoding(variables)
    854 
--> 855     new_vars = {k: encode_cf_variable(v, name=k) for k, v in variables.items()}
    856 
    857     # Remove attrs from bounds variables (issue #2921)

/opt/conda/lib/python3.7/site-packages/xarray/conventions.py in <dictcomp>(.0)
    853     _update_bounds_encoding(variables)
    854 
--> 855     new_vars = {k: encode_cf_variable(v, name=k) for k, v in variables.items()}
    856 
    857     # Remove attrs from bounds variables (issue #2921)

/opt/conda/lib/python3.7/site-packages/xarray/conventions.py in encode_cf_variable(var, needs_copy, name)
    273     var = maybe_default_fill_value(var)
    274     var = maybe_encode_bools(var)
--> 275     var = ensure_dtype_not_object(var, name=name)
    276 
    277     for attr_name in CF_RELATED_DATA:

/opt/conda/lib/python3.7/site-packages/xarray/conventions.py in ensure_dtype_not_object(var, name)
    231             data[missing] = fill_value
    232         else:
--> 233             data = _copy_with_dtype(data, dtype=_infer_dtype(data, name))
    234 
    235         assert data.dtype.kind != "O" or data.dtype.metadata

/opt/conda/lib/python3.7/site-packages/xarray/conventions.py in _infer_dtype(array, name)
    167     raise ValueError(
    168         "unable to infer dtype on variable {!r}; xarray "
--> 169         "cannot serialize arbitrary Python objects".format(name)
    170     )
    171 

ValueError: unable to infer dtype on variable 'changepoints'; xarray cannot serialize arbitrary Python objects

The changepoints are part of the dims, so I’m not sure why they can’t be serialized. Has anyone had any luck saving a trace?

What is changepoints? Could they be multiindex coordinate values?

xarray is more flexible than netCDF, so not all xarray objects can be saved as netCDF files straight away. In some cases you’ll need to downcast some complex types or objects so they become compatible with netCDF. You can also try `arviz.InferenceData.to_zarr`, which is more flexible than netCDF. Keep in mind, though, that both are file formats designed to be read and written from many different languages (one of their main appeals), so by design they can’t support all custom Python types or dtypes.
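For instance, a minimal sketch of the zarr route (assuming `trace` is the `az.InferenceData` object from your failing call; the path is hypothetical):

import arviz as az

# write to a zarr store; zarr is more permissive than netCDF, but it still
# cannot hold arbitrary Python objects, so object-dtype coords may need converting
trace.to_zarr("prophet_trace.zarr")

# read it back
loaded = az.InferenceData.from_zarr("prophet_trace.zarr")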

Thank you!

changepoints is just the integer 8. I went back and tried a good ol’ fashioned pickle, which worked using `'wb'` and `'rb'`. I re-loaded what I saved and it seemed to load correctly, although I’m seeing some Discourse threads where that didn’t seem to work.
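For reference, the pickle round trip looks like this (the filename is just an example):

import pickle

# save the InferenceData with plain pickle
with open("prophet_trace.pkl", "wb") as f:
    pickle.dump(trace, f)

# load it back
with open("prophet_trace.pkl", "rb") as f:
    trace = pickle.load(f)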

Can you share the output of

print(trace.posterior)  # or the group where changepoints is instead of posterior

of the loaded pickle object?

<xarray.Dataset>
Dimensions:                (chain: 4, draw: 1000, locations: 15, items: 217, changepoints: 8, yearly_components: 10, months: 12, obs_id: 35494)
Coordinates:
  * chain                  (chain) int64 0 1 2 3
  * draw                   (draw) int64 0 1 2 3 4 5 ... 994 995 996 997 998 999
  * locations              (locations) <U5 'kr_01' 'kr_02' ... 'kr_49' 'kr_80'
  * items                  (items) <U8 '100179K' '100186K' ... 'SRV0012'
  * changepoints           (changepoints) object 2020-02 2020-03 ... 2020-10
  * yearly_components      (yearly_components) <U12 'yearly_cos_1' ... 'yearl...
  * months                 (months) int64 1 2 3 4 5 6 7 8 9 10 11 12
  * obs_id                 (obs_id) <U33 'kr_01_2020_month_1_item_100179K' .....
Data variables: (12/21)
    mu_slope               (chain, draw) float64 0.5443 0.5424 ... -0.2743
    offset_loc_slope       (chain, draw, locations) float64 -0.1935 ... -0.4139
    offset_item_slope      (chain, draw, items) float64 -0.5775 ... -0.09872
    mu_intercept           (chain, draw) float64 2.844 2.842 ... 0.1531 0.1237
    offset_loc_intercept   (chain, draw, locations) float64 0.5768 ... 0.006356
    offset_item_intercept  (chain, draw, items) float64 0.2861 0.5202 ... -1.245
    ...                     ...
    sigma_item_intercept   (chain, draw) float64 5.441 5.447 ... 5.247 5.24
    sigma_loc_delta        (chain, draw) float64 0.1056 0.1056 ... 0.006914
    sigma_item_delta       (chain, draw) float64 0.05344 0.05345 ... 0.1297
    yearly_sigma           (chain, draw) float64 0.5466 0.5465 ... 0.3494 0.3583
    yearly_seasonality     (chain, draw, obs_id) float64 0.5862 0.5862 ... 3.959
    mu                     (chain, draw, obs_id) float64 5.801 7.074 ... 5.739
Attributes:
    created_at:     2022-08-05T04:08:39.854869
    arviz_version:  0.12.1

I don’t think it matters here whether this is the pre- or post-pickle object.

You can see that the dtype of changepoints is not integer but object, hence the error “unable to infer dtype on variable 'changepoints'”. To serialize it correctly, you should convert it to string or to a datetime dtype. That will be something like:

post = trace.posterior  # `trace` is the InferenceData shown above

# option 1: store the changepoint labels as strings
post["changepoints"] = post["changepoints"].astype(str)

# option 2: store them with a datetime dtype
post["changepoints"] = post["changepoints"].astype("datetime64[ns]")