Is it difficult from an algorithmic point of view or because of the implementation?
I think the “NaN occurred in optimization” errors are a major pain point for using NUTS right now. I’ve seen a couple of models that sampled well when previous chains were used to initialize the scaling, but that don’t work with the ADVI initialization because of some stray NaN early on. As a workaround I’ve used something like this:
with model:
    # Iteratively refine a diagonal scaling estimate with short warm-up runs
    stds = np.ones(model.ndim)
    for _ in range(5):
        args = {'scaling': stds ** 2, 'is_cov': True}
        trace = pm.sample(100, tune=100, init=None, nuts_kwargs=args)
        # Flatten each posterior point and estimate per-dimension stds
        samples = [model.dict_to_array(p) for p in trace]
        stds = np.array(samples).std(axis=0)

    # Run the final chains with the adapted scaling, starting each
    # chain from a different point of the last warm-up trace
    traces = []
    for i in range(2):
        step = pm.NUTS(scaling=stds ** 2, is_cov=True, target_accept=0.9)
        start = trace[-10 * i]
        trace_ = pm.sample(1000, tune=800, chain=i, init=None,
                           step=step, start=start)
        traces.append(trace_)
    trace = pm.backends.base.merge_traces(traces)
I guess we should put that somewhere in NUTS itself; this is similar to what Stan does for initialization.
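For what it's worth, the scaling estimate in the loop above boils down to per-dimension standard deviations of the flattened posterior samples. A minimal numpy-only sketch of that step (the function name `estimate_scaling` is just for illustration, not part of any API):

```python
import numpy as np

def estimate_scaling(samples):
    """Estimate a diagonal covariance (squared per-dimension stds)
    from an array of flattened posterior samples.

    samples: array-like of shape (n_samples, n_dim)
    returns: variances of shape (n_dim,), usable roughly as
             pm.NUTS(scaling=variances, is_cov=True)
    """
    samples = np.asarray(samples)
    return samples.std(axis=0) ** 2

# Example: 500 draws in 3 dimensions with very different scales
rng = np.random.RandomState(0)
draws = rng.normal(0.0, [1.0, 5.0, 0.1], size=(500, 3))
variances = estimate_scaling(draws)
```

Dimensions with larger posterior spread get proportionally larger entries in the diagonal mass matrix, which is exactly what the iterative warm-up is trying to learn.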