Hi,
I am trying to merge multiple traces into a single chain, and then merge many of these single-chain traces along the chain dimension. For reasons related to my application, I stop sampling after tuning (not necessarily right after, so the trace may also contain a few posterior draws) and then resume sampling from the last point. These pieces therefore need to be merged into one chain, and then many such chains need to be combined into a single trace object. This is how I merge the traces into one chain (I might have to do it several times):
```python
import arviz as az

# Groups present in both traces are concatenated along "draw";
# groups present in only one of them are carried over unchanged.
common_groups = [g for g in _tmp_trace.groups() if g in self.trace.groups()]
over_groups_tmp = [g for g in _tmp_trace.groups() if g not in common_groups]
over_groups_selftrace = [
    g for g in self.trace.groups() if g not in common_groups
]
if not over_groups_tmp and not over_groups_selftrace:
    # Both traces have exactly the same groups
    self.trace = az.concat(self.trace, _tmp_trace, dim="draw")
else:
    # Concatenate only the shared groups...
    _new_selftrace = az.InferenceData()
    _new_selftrace.add_groups({g: self.trace[g] for g in common_groups})
    _new_tmptrace = az.InferenceData()
    _new_tmptrace.add_groups({g: _tmp_trace[g] for g in common_groups})
    _new = az.concat(_new_selftrace, _new_tmptrace, dim="draw")
    # ...then add the groups that only one of the traces has
    if over_groups_tmp:
        _new.add_groups({g: _tmp_trace[g] for g in over_groups_tmp})
    if over_groups_selftrace:
        _new.add_groups({g: self.trace[g] for g in over_groups_selftrace})
    self.trace = _new
```
I did not find another way to merge traces that have both overlapping and distinct groups (I want to concatenate the overlapping groups and add the non-overlapping ones as they are). I then simply use `az.concat(trace_1, trace_2, dim="chain")` to merge the chains into a single trace object. Everything works smoothly, but when I try to save the trace with `trace.to_netcdf(path)`, I get the following error:
```
ValueError: setting an array element with a sequence. The requested array has an inhomogeneous shape after 1 dimensions. The detected shape was (3,) + inhomogeneous part.
```
Does anybody know what might be causing this? I don't have this problem if I use `trace.to_json`, but I would like to keep using netCDF.
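The message above is the one NumPy raises when it is asked to build an array from a ragged sequence (for example `np.asarray([[0, 1, 2], [0, 1], [0]])`). One possible cause, not confirmed here, is an attribute or an object-dtype variable that became ragged or non-scalar while the traces were merged, since netCDF cannot store such values. Below is a minimal diagnostic sketch that lists candidates before calling `to_netcdf`. It assumes only that `idata.groups()` lists the group names and `idata[group]` returns an `xarray.Dataset`; the helper names `find_unserializable_attrs` and `report_trace` are hypothetical.

```python
import numpy as np


def find_unserializable_attrs(attrs):
    """Return the keys of `attrs` whose values NumPy cannot turn into a
    homogeneous, non-object array (candidates for the to_netcdf error)."""
    bad = []
    for key, value in attrs.items():
        try:
            arr = np.asarray(value)
        except ValueError:
            # NumPy >= 1.24 raises on ragged sequences such as [[1, 2], [3]]
            bad.append(key)
            continue
        if arr.dtype == object:
            # Older NumPy builds an object array instead; dicts and None also land here
            bad.append(key)
    return bad


def report_trace(idata):
    """Hypothetical helper: list suspicious attributes and object-dtype
    variables in every group of an InferenceData-like object."""
    messages = []
    for group in idata.groups():
        ds = idata[group]  # assumed to be an xarray.Dataset
        for key in find_unserializable_attrs(ds.attrs):
            messages.append(f"{group}: dataset attribute {key!r} = {ds.attrs[key]!r}")
        # .variables covers both data variables and coordinates
        for name, var in ds.variables.items():
            if var.dtype == object:
                messages.append(f"{group}/{name}: object dtype")
            for key in find_unserializable_attrs(var.attrs):
                messages.append(f"{group}/{name}: attribute {key!r} = {var.attrs[key]!r}")
    return messages
```

For example, `for line in report_trace(trace): print(line)` just before `trace.to_netcdf(path)`. Object-dtype variables are only flagged as suspicious: string-valued object arrays can often be written without problems, so each hit still needs a manual look.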
Thanks!